Source code for super_gradients.training.kd_model.kd_model

import torch.nn

from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
from super_gradients.training.models.kd_modules.kd_module import KDModule
from super_gradients.training.sg_model import SgModel
from typing import Union
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training import utils as core_utils
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
from super_gradients.training.utils import get_param
from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, \
    load_checkpoint_to_model
from super_gradients.training.exceptions.kd_model_exceptions import ArchitectureKwargsException, \
    UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
    TeacherKnowledgeException, UndefinedNumClassesException
from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
from super_gradients.training.utils.ema import KDModelEMA
logger = get_logger(__name__)


class KDModel(SgModel):

    def __init__(self, *args, **kwargs):
        super(KDModel, self).__init__(*args, **kwargs)
        self.student_architecture = None
        self.teacher_architecture = None
        self.student_arch_params = None
        self.teacher_arch_params = None

    def build_model(self,  # noqa: C901 - too complex
                    architecture: Union[str, KDModule] = 'kd_module',
                    arch_params={}, checkpoint_params={}, *args, **kwargs):
        """
        :param architecture: (Union[str, KDModule]) Defines the network's architecture from
            models/KD_ARCHITECTURES (default='kd_module')

        :param arch_params: (dict) Architecture hyperparameters, e.g.: block, num_blocks, num_classes, etc.,
            to be passed to the KD architecture class (discarded when architecture is a KDModule instance)

        :param checkpoint_params: (dict) A dictionary-like object with the following keys/values:

            student_pretrained_weights: String describing the dataset of the pretrained weights
                (for example "imagenet") for the student network.

            teacher_pretrained_weights: String describing the dataset of the pretrained weights
                (for example "imagenet") for the teacher network.

            teacher_checkpoint_path: Local path to the teacher's checkpoint. Note that when passing
                pretrained_weights through teacher_arch_params, these weights will be overridden by this
                pretrained checkpoint. (default=None)

            load_kd_model_checkpoint: Whether to load an entire KDModule checkpoint (used to continue
                KD training) (default=False)

            kd_model_source_ckpt_folder_name: Folder name to load an entire KDModule checkpoint from
                (self.experiment_name if none is given) to resume KD training (default=None)

            kd_model_external_checkpoint_path: The path to the external checkpoint to be loaded. Can be
                absolute or relative (i.e.: path/to/checkpoint.pth). If provided, will automatically attempt
                to load the checkpoint even if the load_checkpoint flag is not provided. (default=None)

        :keyword student_architecture: (Union[str, SgModule]) Defines the student's architecture from
            models/ALL_ARCHITECTURES (when str), or directly defines the student network (when SgModule).

        :keyword teacher_architecture: (Union[str, SgModule]) Defines the teacher's architecture from
            models/ALL_ARCHITECTURES (when str), or directly defines the teacher network (when SgModule).

        :keyword student_arch_params: (dict) Architecture hyperparameters for the student net. (default={})

        :keyword teacher_arch_params: (dict) Architecture hyperparameters for the teacher net. (default={})

        :keyword run_teacher_on_eval: (bool) Whether to run self.teacher in eval mode regardless of
            self.train(mode).
        """
        kwargs.setdefault("student_architecture", None)
        kwargs.setdefault("teacher_architecture", None)
        kwargs.setdefault("student_arch_params", {})
        kwargs.setdefault("teacher_arch_params", {})
        kwargs.setdefault("run_teacher_on_eval", False)

        self._validate_args(arch_params, architecture, checkpoint_params, **kwargs)

        self.student_architecture = kwargs.get("student_architecture")
        self.teacher_architecture = kwargs.get("teacher_architecture")
        self.student_arch_params = kwargs.get("student_arch_params")
        self.teacher_arch_params = kwargs.get("teacher_arch_params")

        super(KDModel, self).build_model(architecture=architecture, arch_params=arch_params,
                                         checkpoint_params=checkpoint_params, **kwargs)
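
    # Usage sketch (illustrative only, not part of the library). The architecture names, the
    # "imagenet" pretrained-weights key and the dataset interface instance below are assumptions
    # chosen for the example:
    #
    #     kd_model = KDModel("kd_training_experiment")
    #     kd_model.connect_dataset_interface(dataset_interface)  # any configured dataset interface
    #     kd_model.build_model(student_architecture="resnet18",
    #                          teacher_architecture="resnet50",
    #                          checkpoint_params={"teacher_pretrained_weights": "imagenet"})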

    def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
        student_architecture = get_param(kwargs, "student_architecture")
        teacher_architecture = get_param(kwargs, "teacher_architecture")
        student_arch_params = get_param(kwargs, "student_arch_params")
        teacher_arch_params = get_param(kwargs, "teacher_arch_params")

        if get_param(checkpoint_params, 'pretrained_weights') is not None:
            raise UnsupportedKDModelArgException("pretrained_weights", "checkpoint_params")

        if not isinstance(architecture, KDModule):
            if student_architecture is None or teacher_architecture is None:
                raise ArchitectureKwargsException()
            if architecture not in KD_ARCHITECTURES.keys():
                raise UnsupportedKDArchitectureException(architecture)

        # DERIVE NUMBER OF CLASSES FROM DATASET INTERFACE IF NOT SPECIFIED IN ARCH PARAMS FOR TEACHER AND STUDENT
        self._validate_num_classes(student_arch_params, teacher_arch_params)

        arch_params['num_classes'] = student_arch_params['num_classes']

        # MAKE SURE THE TEACHER'S PRETRAINED NUM CLASSES EQUALS THE STUDENT'S, AS WE CAN'T REPLACE
        # THE TEACHER'S HEAD
        teacher_pretrained_weights = core_utils.get_param(checkpoint_params, 'teacher_pretrained_weights',
                                                          default_val=None)
        if teacher_pretrained_weights is not None:
            teacher_pretrained_num_classes = PRETRAINED_NUM_CLASSES[teacher_pretrained_weights]
            if teacher_pretrained_num_classes != teacher_arch_params['num_classes']:
                raise InconsistentParamsException("Pretrained dataset number of classes", "teacher's arch params",
                                                  "number of classes", "student's number of classes")

        teacher_checkpoint_path = get_param(checkpoint_params, "teacher_checkpoint_path")
        load_kd_model_checkpoint = get_param(checkpoint_params, "load_checkpoint")

        # CHECK THAT THE TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM,
        # OR THAT WE ARE LOADING AN ENTIRE KD CHECKPOINT
        if not (teacher_pretrained_weights or teacher_checkpoint_path or load_kd_model_checkpoint
                or isinstance(teacher_architecture, torch.nn.Module)):
            raise TeacherKnowledgeException()

    def _validate_num_classes(self, student_arch_params, teacher_arch_params):
        """
        Checks the validity of num_classes (i.e. its existence and consistency between the subnets).

        :param student_arch_params: (dict) Architecture hyperparameters for the student
        :param teacher_arch_params: (dict) Architecture hyperparameters for the teacher
        """
        self._validate_subnet_num_classes(student_arch_params)
        self._validate_subnet_num_classes(teacher_arch_params)

        if teacher_arch_params['num_classes'] != student_arch_params['num_classes']:
            raise InconsistentParamsException("num_classes", "student_arch_params",
                                              "num_classes", "teacher_arch_params")

    def _validate_subnet_num_classes(self, subnet_arch_params):
        """
        Derives num_classes in student_arch_params/teacher_arch_params from the dataset interface, or raises
        an error when none is given.

        :param subnet_arch_params: Arch params for the student/teacher
        """
        if 'num_classes' not in subnet_arch_params.keys():
            if self.dataset_interface is None:
                raise UndefinedNumClassesException()
            else:
                subnet_arch_params['num_classes'] = len(self.classes)

    def _instantiate_net(self, architecture: Union[KDModule, KDModule.__class__, str], arch_params: dict,
                         checkpoint_params: dict, *args, **kwargs) -> tuple:
        """
        Instantiates kd_module according to architecture and arch_params, handles pretrained weights for the
        student and teacher networks, and the required module manipulation (i.e. head replacement) for the
        teacher network.

        :param architecture: String, KDModule or uninstantiated KDModule class describing the network's
            architecture.
        :param arch_params: Architecture's parameters passed to the networks' constructor.
        :param checkpoint_params: Checkpoint-loading related parameters dictionary with a 'pretrained_weights'
            key, whose value is a string describing the dataset of the pretrained weights (for example
            "imagenet").

        :return: instantiated network, i.e. KDModule, architecture_class (will be None when architecture
            is not a str)
        """
        student_architecture = get_param(kwargs, "student_architecture")
        teacher_architecture = get_param(kwargs, "teacher_architecture")
        student_arch_params = get_param(kwargs, "student_arch_params")
        teacher_arch_params = get_param(kwargs, "teacher_arch_params")

        student_arch_params = core_utils.HpmStruct(**student_arch_params)
        teacher_arch_params = core_utils.HpmStruct(**teacher_arch_params)

        student_pretrained_weights = get_param(checkpoint_params, 'student_pretrained_weights')
        teacher_pretrained_weights = get_param(checkpoint_params, 'teacher_pretrained_weights')

        student = super()._instantiate_net(student_architecture, student_arch_params,
                                           {"pretrained_weights": student_pretrained_weights})

        teacher = super()._instantiate_net(teacher_architecture, teacher_arch_params,
                                           {"pretrained_weights": teacher_pretrained_weights})

        run_teacher_on_eval = get_param(kwargs, "run_teacher_on_eval", default_val=False)

        if isinstance(architecture, str):
            architecture_cls = KD_ARCHITECTURES[architecture]
            net = architecture_cls(arch_params=arch_params, student=student, teacher=teacher,
                                   run_teacher_on_eval=run_teacher_on_eval)
        elif isinstance(architecture, KDModule.__class__):
            net = architecture(arch_params=arch_params, student=student, teacher=teacher,
                               run_teacher_on_eval=run_teacher_on_eval)
        else:
            net = architecture

        return net

    def _load_checkpoint_to_model(self):
        """
        Initializes teacher weights with teacher_checkpoint_path if needed, then handles checkpoint loading
        for the entire KD network following the same logic as in SgModel.
        """
        teacher_checkpoint_path = get_param(self.checkpoint_params, "teacher_checkpoint_path")
        teacher_net = self.net.module.teacher

        if teacher_checkpoint_path is not None:

            # WARN THAT THE TEACHER CHECKPOINT WILL OVERRIDE THE TEACHER'S PRETRAINED WEIGHTS
            teacher_pretrained_weights = get_param(self.checkpoint_params, "teacher_pretrained_weights")
            if teacher_pretrained_weights:
                logger.warning(teacher_checkpoint_path + " checkpoint is overriding " +
                               teacher_pretrained_weights + " for teacher model")

            # ALWAYS LOAD THE TEACHER'S EMA WEIGHTS IF THEY EXIST
            load_teachers_ema = 'ema_net' in read_ckpt_state_dict(teacher_checkpoint_path).keys()
            load_checkpoint_to_model(ckpt_local_path=teacher_checkpoint_path,
                                     load_backbone=False,
                                     net=teacher_net,
                                     strict='no_key_matching',
                                     load_weights_only=True,
                                     load_ema_as_net=load_teachers_ema)

        super(KDModel, self)._load_checkpoint_to_model()

    def _add_metrics_update_callback(self, phase):
        """
        Adds KDModelMetricsUpdateCallback to be fired at phase.

        :param phase: Phase for the metrics callback to be fired at
        """
        self.phase_callbacks.append(KDModelMetricsUpdateCallback(phase))

    def _get_hyper_param_config(self):
        """
        Creates a training hyperparameter config for logging, with additional KD-related hyperparameters.
""" hyper_param_config = super()._get_hyper_param_config() hyper_param_config.update({"student_architecture": self.student_architecture, "teacher_architecture": self.teacher_architecture, "student_arch_params": self.student_arch_params, "teacher_arch_params": self.teacher_arch_params }) return hyper_param_config def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA: """Instantiate KD ema model for KDModule. If the model is of class KDModule, the instance will be adapted to work on knowledge distillation. :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay) :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to its final value. beta=15 is ~40% of the training process. :param exp_activation: """ return KDModelEMA(self.net, decay, beta, exp_activation) def _save_best_checkpoint(self, epoch, state): """ Overrides parent best_ckpt saving to modify the state dict so that we only save the student. """ if self.ema: best_net = core_utils.WrappedModel(self.ema_model.ema.module.student) state.pop("ema_net") else: best_net = core_utils.WrappedModel(self.net.module.student) state["net"] = best_net.state_dict() self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)