Source code for super_gradients.training.kd_trainer.kd_trainer

import hydra
import torch.nn
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from super_gradients.common import MultiGPUMode
from super_gradients.training.dataloaders import dataloaders
from super_gradients.training.models import SgModule
from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
from super_gradients.training.models.kd_modules.kd_module import KDModule
from super_gradients.training.sg_trainer import Trainer
from typing import Union
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training import utils as core_utils, models
from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
from super_gradients.training.utils import get_param, HpmStruct
from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, \
    load_checkpoint_to_model
from super_gradients.training.exceptions.kd_trainer_exceptions import ArchitectureKwargsException, \
    UnsupportedKDArchitectureException, InconsistentParamsException, UnsupportedKDModelArgException, \
    TeacherKnowledgeException, UndefinedNumClassesException
from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
from super_gradients.training.utils.ema import KDModelEMA
from super_gradients.training.utils.sg_trainer_utils import parse_args

logger = get_logger(__name__)


class KDTrainer(Trainer):
    def __init__(self, experiment_name: str, device: str = None,
                 multi_gpu: Union[MultiGPUMode, str] = MultiGPUMode.OFF, ckpt_root_dir: str = None):
        super().__init__(experiment_name=experiment_name, device=device, multi_gpu=multi_gpu,
                         ckpt_root_dir=ckpt_root_dir)
        self.student_architecture = None
        self.teacher_architecture = None
        self.student_arch_params = None
        self.teacher_arch_params = None

    @classmethod
    def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
        """
        Trains according to cfg recipe configuration.

        @param cfg: The parsed DictConfig from yaml recipe files
        @return: output of kd_trainer.train(...) (i.e results tuple)
        """
        # INSTANTIATE ALL OBJECTS IN CFG
        cfg = hydra.utils.instantiate(cfg)

        kwargs = parse_args(cfg, cls.__init__)

        trainer = KDTrainer(**kwargs)

        # INSTANTIATE DATA LOADERS
        train_dataloader = dataloaders.get(name=cfg.train_dataloader,
                                           dataset_params=cfg.dataset_params.train_dataset_params,
                                           dataloader_params=cfg.dataset_params.train_dataloader_params)

        val_dataloader = dataloaders.get(name=cfg.val_dataloader,
                                         dataset_params=cfg.dataset_params.val_dataset_params,
                                         dataloader_params=cfg.dataset_params.val_dataloader_params)

        student = models.get(cfg.student_architecture, arch_params=cfg.student_arch_params,
                             strict_load=cfg.student_checkpoint_params.strict_load,
                             pretrained_weights=cfg.student_checkpoint_params.pretrained_weights,
                             checkpoint_path=cfg.student_checkpoint_params.checkpoint_path,
                             load_backbone=cfg.student_checkpoint_params.load_backbone)

        teacher = models.get(cfg.teacher_architecture, arch_params=cfg.teacher_arch_params,
                             strict_load=cfg.teacher_checkpoint_params.strict_load,
                             pretrained_weights=cfg.teacher_checkpoint_params.pretrained_weights,
                             checkpoint_path=cfg.teacher_checkpoint_params.checkpoint_path,
                             load_backbone=cfg.teacher_checkpoint_params.load_backbone)

        # TRAIN
        trainer.train(training_params=cfg.training_hyperparams, student=student, teacher=teacher,
                      kd_architecture=cfg.architecture, kd_arch_params=cfg.arch_params,
                      run_teacher_on_eval=cfg.run_teacher_on_eval,
                      train_loader=train_dataloader, valid_loader=val_dataloader)
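    # A minimal, hypothetical usage sketch (not part of the library source): composing a KD recipe with
    # Hydra's compose API and handing the resulting DictConfig to train_from_config. The config_path and
    # recipe name below are illustrative assumptions; a real recipe must define every field accessed above
    # (train_dataloader, val_dataloader, dataset_params, student/teacher architectures and
    # checkpoint_params, architecture, arch_params, run_teacher_on_eval, training_hyperparams).
    #
    #   from hydra import compose, initialize
    #
    #   with initialize(config_path="recipes"):                # hypothetical config directory
    #       cfg = compose(config_name="my_kd_recipe")          # hypothetical recipe name
    #   KDTrainer.train_from_config(cfg)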
    def _validate_args(self, arch_params, architecture, checkpoint_params, **kwargs):
        student_architecture = get_param(kwargs, "student_architecture")
        teacher_architecture = get_param(kwargs, "teacher_architecture")
        student_arch_params = get_param(kwargs, "student_arch_params")
        teacher_arch_params = get_param(kwargs, "teacher_arch_params")

        if get_param(checkpoint_params, 'pretrained_weights') is not None:
            raise UnsupportedKDModelArgException("pretrained_weights", "checkpoint_params")

        if not isinstance(architecture, KDModule):
            if student_architecture is None or teacher_architecture is None:
                raise ArchitectureKwargsException()
            if architecture not in KD_ARCHITECTURES.keys():
                raise UnsupportedKDArchitectureException(architecture)

        # DERIVE NUMBER OF CLASSES FROM DATASET INTERFACE IF NOT SPECIFIED IN ARCH PARAMS FOR TEACHER AND STUDENT
        self._validate_num_classes(student_arch_params, teacher_arch_params)

        arch_params['num_classes'] = student_arch_params['num_classes']

        # MAKE SURE TEACHER'S PRETRAINED NUM CLASSES EQUALS TO THE ONES BELONGING TO STUDENT AS WE CAN'T REPLACE
        # THE TEACHER'S HEAD
        teacher_pretrained_weights = core_utils.get_param(checkpoint_params, 'teacher_pretrained_weights',
                                                          default_val=None)
        if teacher_pretrained_weights is not None:
            teacher_pretrained_num_classes = PRETRAINED_NUM_CLASSES[teacher_pretrained_weights]
            if teacher_pretrained_num_classes != teacher_arch_params['num_classes']:
                raise InconsistentParamsException("Pretrained dataset number of classes", "teacher's arch params",
                                                  "number of classes", "student's number of classes")

        teacher_checkpoint_path = get_param(checkpoint_params, "teacher_checkpoint_path")
        load_kd_model_checkpoint = get_param(checkpoint_params, "load_checkpoint")

        # CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD
        if not (teacher_pretrained_weights or teacher_checkpoint_path or load_kd_model_checkpoint or isinstance(
                teacher_architecture, torch.nn.Module)):
            raise TeacherKnowledgeException()

    def _validate_num_classes(self, student_arch_params, teacher_arch_params):
        """
        Checks validity of num_classes (i.e. existence and consistency between the subnets).

        :param student_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc. for the student
        :param teacher_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc. for the teacher
        """
        self._validate_subnet_num_classes(student_arch_params)
        self._validate_subnet_num_classes(teacher_arch_params)

        if teacher_arch_params['num_classes'] != student_arch_params['num_classes']:
            raise InconsistentParamsException("num_classes", "student_arch_params",
                                              "num_classes", "teacher_arch_params")

    def _validate_subnet_num_classes(self, subnet_arch_params):
        """
        Derives num_classes in student_arch_params/teacher_arch_params from the dataset interface, or raises an error
        when none is given.

        :param subnet_arch_params: Arch params for student/teacher
        """
        if 'num_classes' not in subnet_arch_params.keys():
            if self.dataset_interface is None:
                raise UndefinedNumClassesException()
            else:
                subnet_arch_params['num_classes'] = len(self.classes)

    def _instantiate_net(self, architecture: Union[KDModule, KDModule.__class__, str], arch_params: dict,
                         checkpoint_params: dict, *args, **kwargs) -> tuple:
        """
        Instantiates kd_module according to architecture and arch_params, handles pretrained weights for the student
        and teacher networks, and the required module manipulation (i.e. head replacement) for the teacher network.

        :param architecture: String, KDModule or uninstantiated KDModule class describing the networks' architecture.
        :param arch_params: Architecture's parameters passed to the networks' constructors.
        :param checkpoint_params: checkpoint loading related parameters dictionary with a 'pretrained_weights' key,
            such that its value is a string describing the dataset of the pretrained weights (for example "imagenet").
        :return: instantiated network i.e. KDModule, architecture_class (will be None when architecture is not a str)
        """
        student_architecture = get_param(kwargs, "student_architecture")
        teacher_architecture = get_param(kwargs, "teacher_architecture")
        student_arch_params = get_param(kwargs, "student_arch_params")
        teacher_arch_params = get_param(kwargs, "teacher_arch_params")

        student_arch_params = core_utils.HpmStruct(**student_arch_params)
        teacher_arch_params = core_utils.HpmStruct(**teacher_arch_params)

        student_pretrained_weights = get_param(checkpoint_params, 'student_pretrained_weights')
        teacher_pretrained_weights = get_param(checkpoint_params, 'teacher_pretrained_weights')

        student = super()._instantiate_net(student_architecture, student_arch_params,
                                           {"pretrained_weights": student_pretrained_weights})

        teacher = super()._instantiate_net(teacher_architecture, teacher_arch_params,
                                           {"pretrained_weights": teacher_pretrained_weights})

        run_teacher_on_eval = get_param(kwargs, "run_teacher_on_eval", default_val=False)

        return self._instantiate_kd_net(arch_params, architecture, run_teacher_on_eval, student, teacher)

    def _instantiate_kd_net(self, arch_params, architecture, run_teacher_on_eval, student, teacher):
        if isinstance(architecture, str):
            architecture_cls = KD_ARCHITECTURES[architecture]
            net = architecture_cls(arch_params=arch_params, student=student, teacher=teacher,
                                   run_teacher_on_eval=run_teacher_on_eval)
        elif isinstance(architecture, KDModule.__class__):
            net = architecture(arch_params=arch_params, student=student, teacher=teacher,
                               run_teacher_on_eval=run_teacher_on_eval)
        else:
            net = architecture
        return net

    def _load_checkpoint_to_model(self):
        """
        Initializes teacher weights with teacher_checkpoint_path if needed, then handles checkpoint loading for
        the entire KD network following the same logic as in Trainer.
        """
        teacher_checkpoint_path = get_param(self.checkpoint_params, "teacher_checkpoint_path")
        teacher_net = self.net.module.teacher

        if teacher_checkpoint_path is not None:

            # WARN THAT TEACHER_CKPT WILL OVERRIDE TEACHER'S PRETRAINED WEIGHTS
            teacher_pretrained_weights = get_param(self.checkpoint_params, "teacher_pretrained_weights")
            if teacher_pretrained_weights:
                logger.warning(teacher_checkpoint_path + " checkpoint is overriding " +
                               teacher_pretrained_weights + " for teacher model")

            # ALWAYS LOAD ITS EMA IF IT EXISTS
            load_teachers_ema = 'ema_net' in read_ckpt_state_dict(teacher_checkpoint_path).keys()
            load_checkpoint_to_model(ckpt_local_path=teacher_checkpoint_path,
                                     load_backbone=False,
                                     net=teacher_net,
                                     strict='no_key_matching',
                                     load_weights_only=True,
                                     load_ema_as_net=load_teachers_ema)

        super(KDTrainer, self)._load_checkpoint_to_model()

    def _add_metrics_update_callback(self, phase):
        """
        Adds KDModelMetricsUpdateCallback to be fired at phase

        :param phase: Phase for the metrics callback to be fired at
        """
        self.phase_callbacks.append(KDModelMetricsUpdateCallback(phase))

    def _get_hyper_param_config(self):
        """
        Creates a training hyper param config for logging with additional KD related hyper params.
        """
        hyper_param_config = super()._get_hyper_param_config()
        hyper_param_config.update({"student_architecture": self.student_architecture,
                                   "teacher_architecture": self.teacher_architecture,
                                   "student_arch_params": self.student_arch_params,
                                   "teacher_arch_params": self.teacher_arch_params
                                   })
        return hyper_param_config

    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15,
                               exp_activation: bool = True) -> KDModelEMA:
        """Instantiate KD ema model for KDModule.

        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.

        :param decay: the maximum decay value. As the training process advances, the decay will climb towards this
                      value until EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay)
        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate
                     to its final value. beta=15 is ~40% of the training process.
        :param exp_activation:
        """
        return KDModelEMA(self.net, decay, beta, exp_activation)

    def _save_best_checkpoint(self, epoch, state):
        """
        Overrides parent best_ckpt saving to modify the state dict so that we only save the student.
        """
        if self.ema:
            best_net = core_utils.WrappedModel(self.ema_model.ema.module.student)
            state.pop("ema_net")
        else:
            best_net = core_utils.WrappedModel(self.net.module.student)

        state["net"] = best_net.state_dict()
        self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)
    def train(self, model: KDModule = None, training_params: dict = dict(), student: SgModule = None,
              teacher: torch.nn.Module = None, kd_architecture: Union[KDModule.__class__, str] = 'kd_module',
              kd_arch_params: dict = dict(), run_teacher_on_eval=False, train_loader: DataLoader = None,
              valid_loader: DataLoader = None, *args, **kwargs):
        """
        Trains the student network (wrapped in a KDModule network).

        :param model: KDModule, network to train. When none is given, a KDModule will be initialized according to
            kd_architecture, student and teacher (default=None)
        :param training_params: dict, same as in Trainer.train()
        :param student: SgModule - the student network
        :param teacher: torch.nn.Module - the teacher network
        :param kd_architecture: KDModule architecture to use, currently only 'kd_module' is supported
            (default='kd_module').
        :param kd_arch_params: architecture params to pass to the kd_architecture constructor.
        :param run_teacher_on_eval: bool - whether to run self.teacher in eval mode regardless of self.train(mode)
        :param train_loader: Dataloader for the train set.
        :param valid_loader: Dataloader for the validation set.
        """
        kd_net = self.net or model
        if kd_net is None:
            if student is None or teacher is None:
                raise ValueError("Must pass student and teacher models or net (KDModule).")
            kd_net = self._instantiate_kd_net(arch_params=HpmStruct(**kd_arch_params),
                                              architecture=kd_architecture,
                                              run_teacher_on_eval=run_teacher_on_eval,
                                              student=student,
                                              teacher=teacher)
        super(KDTrainer, self).train(model=kd_net, training_params=training_params,
                                     train_loader=train_loader, valid_loader=valid_loader)
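# A minimal, hypothetical usage sketch (not part of the library source): the programmatic path in which
# train() assembles the KDModule from an explicit student and teacher. The architecture names, dataloader
# names, class count and training_hyperparams contents are illustrative assumptions only; training_params
# follows the same schema as Trainer.train() and should use a KD-aware loss. In practice the teacher would
# also load pretrained weights or a checkpoint so it holds knowledge to distill.
#
#   from super_gradients.training import models
#   from super_gradients.training.dataloaders import dataloaders
#   from super_gradients.training.kd_trainer.kd_trainer import KDTrainer
#
#   trainer = KDTrainer(experiment_name="kd_example")                    # hypothetical experiment name
#   student = models.get("resnet18", arch_params={"num_classes": 10})    # placeholder architectures
#   teacher = models.get("resnet50", arch_params={"num_classes": 10})
#   train_loader = dataloaders.get("cifar10_train")                      # placeholder dataloader names
#   valid_loader = dataloaders.get("cifar10_val")
#
#   trainer.train(training_params=training_hyperparams,                  # dict as in Trainer.train()
#                 student=student, teacher=teacher,
#                 train_loader=train_loader, valid_loader=valid_loader)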