import torch.optim as optim
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.params import DEFAULT_OPTIMIZER_PARAMS_SGD, DEFAULT_OPTIMIZER_PARAMS_ADAM, \
    DEFAULT_OPTIMIZER_PARAMS_RMSPROP, DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF
from super_gradients.training.utils import get_param
from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF

logger = get_logger(__name__)

OPTIMIZERS_DICT = {"SGD": {"class": optim.SGD, "params": DEFAULT_OPTIMIZER_PARAMS_SGD},
                   "Adam": {"class": optim.Adam, "params": DEFAULT_OPTIMIZER_PARAMS_ADAM},
                   "RMSprop": {"class": optim.RMSprop, "params": DEFAULT_OPTIMIZER_PARAMS_RMSPROP},
                   "RMSpropTF": {"class": RMSpropTF, "params": DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF}}
def separate_zero_wd_params_groups_for_optimizer(module: nn.Module, net_named_params, weight_decay: float):
    """
    Separate the parameters into batch-norm/bias groups and all other groups, and return a list of param
    groups in the format required by torch Optimizer classes: biases and batch-norm parameters get
    weight_decay=0, and the rest get the given weight decay.

    :param module: train net module.
    :param net_named_params: list of param groups, output of SgModule.initialize_param_groups
    :param weight_decay: value to set for the non-BN and non-bias parameters
    """
    # FIXME - replace usage of id addresses to find batchnorm and bias params.
    # This solution iterates twice over the module parameters; find a way to iterate only once.
    no_decay_ids = _get_no_decay_param_ids(module)

    # split param groups for the optimizer
    optimizer_param_groups = []
    for param_group in net_named_params:
        no_decay_params = []
        decay_params = []
        for name, param in param_group["named_params"]:
            if id(param) in no_decay_ids:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        # append two param groups derived from the original param group, with and without weight decay.
        extra_optim_params = {key: param_group[key] for key in param_group
                              if key not in ["named_params", "weight_decay"]}
        optimizer_param_groups.append({"params": no_decay_params, "weight_decay": 0.0, **extra_optim_params})
        optimizer_param_groups.append({"params": decay_params, "weight_decay": weight_decay, **extra_optim_params})

    return optimizer_param_groups
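

# Hedged usage sketch (illustrative helper, not part of the library API): split a toy model's
# parameters with separate_zero_wd_params_groups_for_optimizer() and feed the resulting groups to
# a plain torch SGD optimizer. The toy model and the 4e-5 weight decay value are made up here.
def _example_separate_zero_wd_groups():
    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.Linear(8, 2))
    named_params = [{"named_params": list(model.named_parameters())}]
    param_groups = separate_zero_wd_params_groups_for_optimizer(model, named_params, weight_decay=4e-5)
    # Expected result: one group holding the BN params and the conv/linear biases with weight_decay=0.0,
    # and one group holding the conv/linear weights with weight_decay=4e-5.
    return optim.SGD(param_groups, lr=0.1)

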
def _get_no_decay_param_ids(module: nn.Module):
    """
    Iterate over module.modules() and return the id addresses of batch-norm and bias params.
    NOTE - EVERY MODULE WITH AN ATTRIBUTE NAMED `bias` THAT IS AN INSTANCE OF nn.Parameter WILL BE CONSIDERED
    A BIAS PARAM FOR ZERO WEIGHT DECAY.
    """
    # FIXME - replace usage of id addresses to find batchnorm and bias params.
    # Use another common way to identify torch parameters other than id or layer names.
    batchnorm_types = (_BatchNorm,)
    torch_weight_with_bias_types = (_ConvNd, nn.Linear)
    no_decay_ids = []
    for name, m in module.named_modules():
        if isinstance(m, batchnorm_types):
            no_decay_ids.append(id(m.weight))
            no_decay_ids.append(id(m.bias))
        elif hasattr(m, "bias") and isinstance(m.bias, nn.Parameter):
            if not isinstance(m, torch_weight_with_bias_types):
                logger.warning(f"Module class: {m.__class__} has a `bias` parameter attribute but is not an instance"
                               f" of the torch primitive modules; this bias parameter will be part of the param group"
                               f" with zero weight decay.")
            no_decay_ids.append(id(m.bias))

    return no_decay_ids
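

# Hedged sketch (illustrative helper, not part of the library API): list which named parameters of a
# small model _get_no_decay_param_ids() exempts from weight decay, including a custom module whose
# `bias` attribute triggers the warning documented above. Everything below is made up for illustration.
def _example_list_no_decay_params():
    import torch  # local import keeps this illustrative helper self-contained

    class _WithCustomBias(nn.Module):
        # Not a torch primitive (_ConvNd / nn.Linear), yet exposes a `bias` nn.Parameter,
        # so its bias is exempted from weight decay and a warning is logged.
        def __init__(self):
            super().__init__()
            self.scale = nn.Parameter(torch.ones(8))
            self.bias = nn.Parameter(torch.zeros(8))

    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), _WithCustomBias())
    no_decay_ids = _get_no_decay_param_ids(model)
    # Expected names: the conv bias, both BatchNorm parameters, and the custom `bias` (but not `scale`).
    return [name for name, param in model.named_parameters() if id(param) in no_decay_ids]

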
def build_optimizer(net, lr, training_params):
    """
    Wrapper function for initializing the optimizer.

    :param net: the nn_module to build the optimizer for
    :param lr: initial learning rate
    :param training_params: training_parameters
    """
    default_optimizer_params = OPTIMIZERS_DICT[training_params.optimizer]["params"]
    training_params.optimizer_params = get_param(training_params, 'optimizer_params', default_optimizer_params)

    # OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
    if hasattr(net.module, 'initialize_param_groups'):
        # INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH 'named_params' AND OPTIMIZER's ATTRIBUTES PER GROUP
        net_named_params = net.module.initialize_param_groups(lr, training_params)
    else:
        net_named_params = [{'named_params': net.named_parameters()}]

    if training_params.zero_weight_decay_on_bias_and_bn:
        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(
            net.module, net_named_params, training_params.optimizer_params['weight_decay']
        )
    else:
        # Overwrite groups to include params instead of named params
        for ind_group, param_group in enumerate(net_named_params):
            param_group['params'] = [param[1] for param in list(param_group['named_params'])]
            del param_group['named_params']
            net_named_params[ind_group] = param_group
        optimizer_training_params = net_named_params

    # CREATE AN OPTIMIZER OBJECT AND INITIALIZE IT
    optimizer_cls = OPTIMIZERS_DICT[training_params.optimizer]["class"]
    optimizer = optimizer_cls(optimizer_training_params, lr=lr, **training_params.optimizer_params)
    return optimizer
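

# Hedged end-to-end sketch (illustrative helper, not part of the library API): build an optimizer
# through build_optimizer() for a toy DataParallel-wrapped model. SimpleNamespace merely stands in
# for the real training-params object; it is assumed to be compatible with get_param() and with the
# attribute accesses above, and DEFAULT_OPTIMIZER_PARAMS_SGD is assumed to contain a 'weight_decay'
# entry, as the zero-weight-decay branch requires. Treat this strictly as an illustration.
def _example_build_optimizer():
    from types import SimpleNamespace

    net = nn.DataParallel(nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8)))
    training_params = SimpleNamespace(
        optimizer="SGD",
        optimizer_params=DEFAULT_OPTIMIZER_PARAMS_SGD,
        zero_weight_decay_on_bias_and_bn=True,
    )
    return build_optimizer(net, lr=0.01, training_params=training_params)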