Source code for super_gradients.training.utils.optimizer_utils

import torch.optim as optim
import torch.nn as nn

from super_gradients.training.params import DEFAULT_OPTIMIZER_PARAMS_SGD, DEFAULT_OPTIMIZER_PARAMS_ADAM, \
    DEFAULT_OPTIMIZER_PARAMS_RMSPROP, DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF
from super_gradients.training.utils import get_param
from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF

OPTIMIZERS_DICT = {"SGD": {"class": optim.SGD, "params": DEFAULT_OPTIMIZER_PARAMS_SGD},
                   "Adam": {"class": optim.Adam, "params": DEFAULT_OPTIMIZER_PARAMS_ADAM},
                   "RMSprop": {"class": optim.RMSprop, "params": DEFAULT_OPTIMIZER_PARAMS_RMSPROP},
                   "RMSpropTF": {"class": RMSpropTF, "params": DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF}}


def separate_zero_wd_params_groups_for_optimizer(module: nn.Module, net_named_params, weight_decay: float):
    """
    Separate param groups for batchnorm and biases from the others with weight decay.
    Return a list of param groups in the format required by torch Optimizer classes:
    bias + BN params with weight_decay=0 and the rest with the given weight decay.

    :param module: train net module.
    :param net_named_params: list of param groups, output of SgModule.initialize_param_groups
    :param weight_decay: value to set for the non-BN and non-bias parameters
    """
    # FIXME - replace usage of id addresses to find batchnorm and bias params.
    # This solution iterates twice over the module parameters; find a way to iterate only once.
    no_decay_ids = _get_no_decay_param_ids(module)

    # split param groups for optimizer
    optimizer_param_groups = []
    for param_group in net_named_params:
        no_decay_params = []
        decay_params = []
        for name, param in param_group["named_params"]:
            if id(param) in no_decay_ids:
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        # append two param groups from the original param group, with and without weight decay.
        extra_optim_params = {key: param_group[key] for key in param_group
                              if key not in ["named_params", "weight_decay"]}
        optimizer_param_groups.append({"params": no_decay_params, "weight_decay": 0.0, **extra_optim_params})
        optimizer_param_groups.append({"params": decay_params, "weight_decay": weight_decay, **extra_optim_params})

    return optimizer_param_groups
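For clarity, a minimal usage sketch of the function above; the toy Conv+BN model and the extra "lr" key are assumptions used only for illustration, not part of this module:

# Illustrative only: split a tiny Conv+BN model into decay / no-decay groups.
model = nn.Sequential(nn.Conv2d(3, 8, 3, bias=True), nn.BatchNorm2d(8))
named_groups = [{"named_params": list(model.named_parameters()), "lr": 0.1}]
groups = separate_zero_wd_params_groups_for_optimizer(model, named_groups, weight_decay=1e-4)
# groups now holds two entries for the single input group: one with weight_decay=0.0
# (the conv bias and both BatchNorm params) and one with weight_decay=1e-4 (the conv weight).
# Extra keys such as "lr" are copied to both entries.
optimizer = optim.SGD(groups, lr=0.1, momentum=0.9)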
def _get_no_decay_param_ids(module: nn.Module):
    """
    Iterate over module.modules() and return the param id addresses of batch-norm and bias params.
    """
    # FIXME - replace usage of id addresses to find batchnorm and bias params.
    # Use another common way to identify torch parameters other than id or layer names.
    weight_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
                    nn.Linear)
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    no_decay_ids = []
    for name, m in module.named_modules():
        if isinstance(m, batchnorm_types):
            no_decay_ids.append(id(m.weight))
            no_decay_ids.append(id(m.bias))
        elif isinstance(m, weight_types) and m.bias is not None:
            no_decay_ids.append(id(m.bias))
    return no_decay_ids
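A small sanity-check sketch for the helper above (the toy Linear+BatchNorm module is an assumption used only for illustration):

# Illustrative only: biases and all BatchNorm params are flagged as no-decay,
# while conv/linear weights are not.
toy = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
flagged = _get_no_decay_param_ids(toy)
assert id(toy[0].bias) in flagged and id(toy[1].weight) in flagged
assert id(toy[0].weight) not in flagged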
def build_optimizer(net, lr, training_params):
    """
    Wrapper function for initializing the optimizer.

    :param net: the nn_module to build the optimizer for
    :param lr: initial learning rate
    :param training_params: training_parameters
    """
    default_optimizer_params = OPTIMIZERS_DICT[training_params.optimizer]["params"]
    training_params.optimizer_params = get_param(training_params, 'optimizer_params', default_optimizer_params)

    # OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
    if hasattr(net.module, 'initialize_param_groups'):
        # INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH 'named_params' AND OPTIMIZER's ATTRIBUTES PER GROUP
        net_named_params = net.module.initialize_param_groups(lr, training_params)
    else:
        net_named_params = [{'named_params': net.named_parameters()}]

    if training_params.zero_weight_decay_on_bias_and_bn:
        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(
            net.module, net_named_params, training_params.optimizer_params['weight_decay']
        )
    else:
        # Overwrite groups to include params instead of named params
        for ind_group, param_group in enumerate(net_named_params):
            param_group['params'] = [param[1] for param in list(param_group['named_params'])]
            del param_group['named_params']
            net_named_params[ind_group] = param_group
        optimizer_training_params = net_named_params

    # CREATE AN OPTIMIZER OBJECT AND INITIALIZE IT
    optimizer_cls = OPTIMIZERS_DICT[training_params.optimizer]["class"]
    optimizer = optimizer_cls(optimizer_training_params, lr=lr, **training_params.optimizer_params)

    return optimizer
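A rough call sketch, assuming net is already wrapped so that net.module exists (e.g. by torch.nn.DataParallel) and that training_params exposes the attributes read above; the SimpleNamespace below stands in for SuperGradients' own hyper-parameter object and is an illustration only:

from types import SimpleNamespace

# Illustrative only: minimal stand-in for the training params consumed by build_optimizer.
params = SimpleNamespace(
    optimizer="SGD",
    optimizer_params={"weight_decay": 1e-4, "momentum": 0.9},
    zero_weight_decay_on_bias_and_bn=True,
)
wrapped_net = nn.DataParallel(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
optimizer = build_optimizer(wrapped_net, lr=0.01, training_params=params)
# With zero_weight_decay_on_bias_and_bn=True the optimizer receives two param groups per
# original group: BN params and biases with weight_decay=0.0, the rest with 1e-4.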