Module imodels.util.distillation
Expand source code
from sklearn.base import RegressorMixin, BaseEstimator, is_regressor
class DistilledRegressor(BaseEstimator, RegressorMixin):
    """
    Class to implement distillation. Currently only supports regression.

    The teacher is fit on ``(X, y)``; the student is then fit on ``X`` with
    the teacher's outputs as targets. A regression teacher supplies
    ``teacher.predict(X)``; any other teacher supplies
    ``teacher.predict_proba(X)[:, 1]`` (column 1 of the class
    probabilities, i.e. the positive class of a binary classifier).

    Params
    ------
    teacher: initial model to be trained
    student: model to be distilled from teacher's predictions; must be a
        scikit-learn regressor or an imodels model whose ``prediction_task``
        is not ``"classification"``

    Raises
    ------
    ValueError: at construction (and again in ``fit``) if the student is
        not a regressor
    """

    def __init__(self, teacher: BaseEstimator, student: BaseEstimator):
        self.teacher = teacher
        self.student = student
        # Validate eagerly so a misconfigured student fails at construction.
        self._validate_student()
        self._check_teacher_type()

    def _validate_student(self):
        """Raise ``ValueError`` unless ``self.student`` is a regressor."""
        if is_regressor(self.student):
            return
        # Not a scikit-learn regressor: fall back to the imodels convention.
        if not hasattr(self.student, "prediction_task"):
            raise ValueError("Student must be either a scikit-learn or imodels regressor")
        if self.student.prediction_task == "classification":
            raise ValueError("Student must be a regressor")

    def _check_teacher_type(self):
        """Set ``self.teacher_type`` from the teacher's declared task.

        Precedence: the imodels ``prediction_task`` attribute, then
        scikit-learn's ``_estimator_type`` (via ``is_regressor``). A teacher
        exposing neither is treated as a classifier if it has
        ``predict_proba`` and as a regressor otherwise. Previously such a
        teacher left ``teacher_type`` unset, so ``fit`` failed with
        ``AttributeError``.
        """
        if hasattr(self.teacher, "prediction_task"):
            self.teacher_type = self.teacher.prediction_task
        elif hasattr(self.teacher, "_estimator_type"):
            self.teacher_type = "regression" if is_regressor(self.teacher) else "classification"
        elif hasattr(self.teacher, "predict_proba"):
            self.teacher_type = "classification"
        else:
            self.teacher_type = "regression"

    def set_teacher_params(self, **params):
        """Forward ``params`` to ``self.teacher.set_params``; return ``self``."""
        self.teacher.set_params(**params)
        return self

    def set_student_params(self, **params):
        """Forward ``params`` to ``self.student.set_params``; return ``self``."""
        self.student.set_params(**params)
        return self

    def fit(self, X, y, **kwargs):
        """Fit the teacher on ``(X, y)``, then the student on the teacher's outputs.

        Extra keyword arguments are forwarded to ``teacher.fit`` only.
        Returns ``self``, following the scikit-learn convention.
        """
        # teacher/student may have been replaced via set_params() since
        # __init__ ran, so re-validate and recompute teacher_type here.
        self._validate_student()
        self._check_teacher_type()
        self.teacher.fit(X, y, **kwargs)
        if self.teacher_type == "regression":
            preds = self.teacher.predict(X)
        else:
            # Only column 1 is kept: the positive-class score of a binary classifier.
            preds = self.teacher.predict_proba(X)[:, 1]
        self.student.fit(X, preds)
        return self

    def predict(self, X):
        """Predict with the fitted student model."""
        return self.student.predict(X)
Classes
class DistilledRegressor (teacher: sklearn.base.BaseEstimator, student: sklearn.base.BaseEstimator)
-
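Class to implement distillation: the teacher is fit on the training data, then the student regressor is fit on the teacher's predictions. Currently only supports regression students; when the teacher is a classifier, its `predict_proba(X)[:, 1]` scores are used as the student's targets.
Params — teacher: initial model to be trained; student: model to be distilled from the teacher's predictions.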
Expand source code
class DistilledRegressor(BaseEstimator, RegressorMixin):
    """
    Class to implement distillation. Currently only supports regression.

    The teacher is fit on ``(X, y)``; the student is then fit on ``X`` with
    the teacher's outputs as targets. A regression teacher supplies
    ``teacher.predict(X)``; any other teacher supplies
    ``teacher.predict_proba(X)[:, 1]`` (column 1 of the class
    probabilities, i.e. the positive class of a binary classifier).

    Params
    ------
    teacher: initial model to be trained
    student: model to be distilled from teacher's predictions; must be a
        scikit-learn regressor or an imodels model whose ``prediction_task``
        is not ``"classification"``

    Raises
    ------
    ValueError: at construction (and again in ``fit``) if the student is
        not a regressor
    """

    def __init__(self, teacher: BaseEstimator, student: BaseEstimator):
        self.teacher = teacher
        self.student = student
        # Validate eagerly so a misconfigured student fails at construction.
        self._validate_student()
        self._check_teacher_type()

    def _validate_student(self):
        """Raise ``ValueError`` unless ``self.student`` is a regressor."""
        if is_regressor(self.student):
            return
        # Not a scikit-learn regressor: fall back to the imodels convention.
        if not hasattr(self.student, "prediction_task"):
            raise ValueError("Student must be either a scikit-learn or imodels regressor")
        if self.student.prediction_task == "classification":
            raise ValueError("Student must be a regressor")

    def _check_teacher_type(self):
        """Set ``self.teacher_type`` from the teacher's declared task.

        Precedence: the imodels ``prediction_task`` attribute, then
        scikit-learn's ``_estimator_type`` (via ``is_regressor``). A teacher
        exposing neither is treated as a classifier if it has
        ``predict_proba`` and as a regressor otherwise. Previously such a
        teacher left ``teacher_type`` unset, so ``fit`` failed with
        ``AttributeError``.
        """
        if hasattr(self.teacher, "prediction_task"):
            self.teacher_type = self.teacher.prediction_task
        elif hasattr(self.teacher, "_estimator_type"):
            self.teacher_type = "regression" if is_regressor(self.teacher) else "classification"
        elif hasattr(self.teacher, "predict_proba"):
            self.teacher_type = "classification"
        else:
            self.teacher_type = "regression"

    def set_teacher_params(self, **params):
        """Forward ``params`` to ``self.teacher.set_params``; return ``self``."""
        self.teacher.set_params(**params)
        return self

    def set_student_params(self, **params):
        """Forward ``params`` to ``self.student.set_params``; return ``self``."""
        self.student.set_params(**params)
        return self

    def fit(self, X, y, **kwargs):
        """Fit the teacher on ``(X, y)``, then the student on the teacher's outputs.

        Extra keyword arguments are forwarded to ``teacher.fit`` only.
        Returns ``self``, following the scikit-learn convention.
        """
        # teacher/student may have been replaced via set_params() since
        # __init__ ran, so re-validate and recompute teacher_type here.
        self._validate_student()
        self._check_teacher_type()
        self.teacher.fit(X, y, **kwargs)
        if self.teacher_type == "regression":
            preds = self.teacher.predict(X)
        else:
            # Only column 1 is kept: the positive-class score of a binary classifier.
            preds = self.teacher.predict_proba(X)[:, 1]
        self.student.fit(X, preds)
        return self

    def predict(self, X):
        """Predict with the fitted student model."""
        return self.student.predict(X)
Ancestors
- sklearn.base.BaseEstimator
- sklearn.base.RegressorMixin
Methods
def fit(self, X, y, **kwargs)
-
Expand source code
def fit(self, X, y, **kwargs):
    """Fit the teacher on ``(X, y)``, then the student on the teacher's outputs.

    Extra keyword arguments are forwarded to ``teacher.fit`` only.
    Returns ``self``, following the scikit-learn convention.
    """
    # teacher/student may have been replaced via set_params() since
    # __init__ ran, so re-validate and recompute teacher_type here.
    self._validate_student()
    self._check_teacher_type()
    self.teacher.fit(X, y, **kwargs)
    if self.teacher_type == "regression":
        preds = self.teacher.predict(X)
    else:
        # Only column 1 is kept: the positive-class score of a binary classifier.
        preds = self.teacher.predict_proba(X)[:, 1]
    self.student.fit(X, preds)
    return self
def predict(self, X)
-
Expand source code
def predict(self, X):
    """Return the student's predictions for ``X``.

    The teacher is not consulted here; inference uses the student alone.
    """
    fitted_student = self.student
    predictions = fitted_student.predict(X)
    return predictions
def set_student_params(self, **params)
-
Expand source code
def set_student_params(self, **params):
    """Forward ``params`` to ``self.student.set_params``.

    Returns ``self`` so calls can be chained, mirroring scikit-learn's
    ``set_params`` (the original returned ``None``).
    """
    self.student.set_params(**params)
    return self
def set_teacher_params(self, **params)
-
Expand source code
def set_teacher_params(self, **params):
    """Forward ``params`` to ``self.teacher.set_params``.

    Returns ``self`` so calls can be chained, mirroring scikit-learn's
    ``set_params`` (the original returned ``None``).
    """
    self.teacher.set_params(**params)
    return self