# encoding: utf-8
__author__ = "Dimitrios Karkalousos"
# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/core/optim/lr_scheduler.py
import copy
import dataclasses
import math
import warnings
from functools import partial
from typing import Any, Dict, Optional, Union
import hydra
import torch.optim as optim
import torch.optim.lr_scheduler as pt_scheduler
import torch.utils.data.dataloader as dataloader
from omegaconf import DictConfig, OmegaConf
from torch.optim.lr_scheduler import _LRScheduler # type: ignore
from mridc.core.conf.schedulers import SchedulerParams, get_scheduler_config, register_scheduler_params
from mridc.utils import logging
from mridc.utils.model_utils import maybe_update_config_version


class WarmupPolicy(_LRScheduler):
"""Adds warmup kwargs and warmup logic to lr policy. All arguments should be passed as kwargs for clarity.
Parameters
----------
warmup_steps: Number of training steps in warmup stage.
warmup_ratio: Ratio of warmup steps to total steps.
max_steps: Total number of steps while training or `None` for infinite training.
Returns
-------
lr: Learning rate for current step.
"""
def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1):
"""
Parameters
----------
optimizer: optimizer
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
max_steps: Total number of steps while training or `None` for infinite training
min_lr: Minimum learning rate
last_epoch: Last epoch
"""
        if warmup_steps is not None and warmup_ratio is not None:
            raise AssertionError("Either use a particular number of steps or a ratio, not both.")
        if warmup_ratio is not None and max_steps is None:
            raise AssertionError("If warmup_ratio is provided, max_steps must also be provided.")
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)

    def get_lr(self):
"""Get learning rate at current step."""
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
)
step = self.last_epoch
        if self.warmup_steps > 0 and step <= self.warmup_steps:
return self._get_warmup_lr(step)
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_warmup_lr(self, step):
"""Linear warmup"""
lr_val = (step + 1) / (self.warmup_steps + 1)
return [initial_lr * lr_val for initial_lr in self.base_lrs]
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
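
# Illustrative usage sketch (not part of the original module; values arbitrary):
# WarmupPolicy ramps the lr linearly towards the optimizer's base lr over
# `warmup_steps`, then holds it constant until `max_steps`.
#
#     import torch
#     opt = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
#     sched = WarmupPolicy(opt, warmup_steps=10, max_steps=100, min_lr=1e-4)
#     for _ in range(100):
#         opt.step()
#         sched.step()
#     print(sched.get_last_lr())  # [0.1] once warmup is over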


class SquareRootConstantPolicy(_LRScheduler):
    """Holds the learning rate at ``1 / constant_steps ** 0.5`` for a fixed number of steps. All arguments should be
    passed as kwargs for clarity.
    Parameters
    ----------
    constant_steps: Number of training steps in the constant stage
    constant_ratio: Ratio of constant steps to total steps
    max_steps: Total number of steps while training or `None` for infinite training
    """
def __init__(
self, optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
):
"""
Parameters
----------
optimizer: optimizer
constant_steps: Number of training steps in constant stage
constant_ratio: Ratio of constant steps to total steps
max_steps: Total number of steps while training or `None` for infinite training
min_lr: Minimum learning rate
last_epoch: Last epoch
"""
        if constant_steps is not None and constant_ratio is not None:
            raise AssertionError("Either use a particular number of steps or a ratio, not both.")
        if constant_ratio is not None and max_steps is None:
            raise AssertionError("If constant_ratio is provided, max_steps must also be provided.")
# It is necessary to assign all attributes *before* __init__, as class is wrapped by an inner class.
self.max_steps = max_steps
if constant_steps is not None:
self.constant_steps = constant_steps
elif constant_ratio is not None:
self.constant_steps = int(constant_ratio * max_steps)
else:
self.constant_steps = 0
        # Use the resolved ``self.constant_steps`` (never None, unlike the raw kwarg)
        # and guard against division by zero when no constant phase is configured.
        self.constant_lr = 1 / (self.constant_steps**0.5) if self.constant_steps > 0 else min_lr
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)

    def get_lr(self):
"""Get learning rate at current step."""
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
)
step = self.last_epoch
if step <= self.constant_steps:
return [self.constant_lr for _ in self.base_lrs]
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
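
# Worked example (illustrative; values arbitrary): with constant_steps=10_000 the
# policy holds lr = 1 / 10_000 ** 0.5 = 0.01 for the first 10_000 steps, then
# falls back to the optimizer's base lrs until max_steps, and min_lr afterwards.
#
#     sched = SquareRootConstantPolicy(opt, constant_steps=10_000, max_steps=100_000)
#     # sched.constant_lr == 0.01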


class WarmupHoldPolicy(WarmupPolicy):
    """
    Variant of WarmupPolicy which maintains a high learning rate for a defined number of steps. The learning rate is
    increased linearly from 0 to the base lr over the warmup steps, then held at the base lr for the hold steps
    before the decay policy of the subclass takes over. All arguments should be passed as kwargs for clarity.
    Parameters
    ----------
    warmup_steps: Number of training steps in warmup stage
    warmup_ratio: Ratio of warmup steps to total steps
    hold_steps: Number of training steps to hold the learning rate after warm up
    hold_ratio: Ratio of hold steps to total steps
    max_steps: Total number of steps while training or `None` for infinite training
    """
def __init__(
self,
optimizer,
*,
warmup_steps=None,
warmup_ratio=None,
hold_steps=None,
hold_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
"""
Parameters
----------
optimizer: optimizer
warmup_steps: Number of training steps in warmup stage.
warmup_ratio: Ratio of warmup steps to total steps.
hold_steps: Number of training steps to hold the learning rate after warm up.
hold_ratio: Ratio of hold steps to total steps.
max_steps: Total number of steps while training or `None` for infinite training.
min_lr: Minimum learning rate.
last_epoch: Last epoch.
"""
        if hold_steps is not None and hold_ratio is not None:
            raise AssertionError("Either use a particular number of steps or a ratio, not both.")
        if hold_ratio is not None and max_steps is None:
            raise AssertionError("If hold_ratio is provided, max_steps must also be provided.")
self.min_lr = min_lr
self._last_warmup_lr = 0.0
# Necessary to duplicate as class attributes are hidden in inner class
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
if hold_steps is not None:
self.hold_steps = hold_steps + self.warmup_steps
elif hold_ratio is not None:
self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
else:
self.hold_steps = 0
super().__init__(
optimizer,
warmup_steps=warmup_steps,
warmup_ratio=warmup_ratio,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
)

    def get_lr(self):
"""Get learning rate at current step."""
if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
            )
step = self.last_epoch
        # Warmup phase
        if self.warmup_steps > 0 and step <= self.warmup_steps:
            return self._get_warmup_lr(step)
        # Hold phase: keep the base lr from the end of warmup until hold_steps is reached
        if self.warmup_steps <= step < self.hold_steps:
return self.base_lrs
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
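
# Illustrative phase timeline (values arbitrary): with warmup_steps=10 and
# hold_steps=20, `self.hold_steps` resolves to 30 because hold is counted from
# the end of warmup. The lr thus ramps up over steps 0-10, stays at the base lr
# until step 30, and then follows the decay implemented by the subclass.
#
#     sched = WarmupHoldPolicy(opt, warmup_steps=10, hold_steps=20, max_steps=100)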


class WarmupAnnealHoldPolicy(_LRScheduler):
"""
Adds warmup kwargs and warmup logic to lr policy. All arguments should be passed as kwargs for clarity.
Parameters
----------
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
max_steps: Total number of steps while training or `None` for infinite training
min_lr: Minimum lr to hold the learning rate after decay at.
constant_steps: Number of steps to keep lr constant at.
constant_ratio: Ratio of steps to keep lr constant.
"""
def __init__(
self,
optimizer,
*,
warmup_steps=None,
warmup_ratio=None,
constant_steps=None,
constant_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
"""
Parameters
----------
optimizer: Optimizer
warmup_steps: Number of training steps in warmup stage.
warmup_ratio: Ratio of warmup steps to total steps.
constant_steps: Number of steps to keep lr constant at.
constant_ratio: Ratio of steps to keep lr constant.
max_steps: Total number of steps while training or `None` for infinite training.
min_lr: Minimum lr to hold the learning rate after decay at.
last_epoch: The index of last epoch.
"""
        if warmup_steps is not None and warmup_ratio is not None:
            raise AssertionError("Either use a particular number of steps or a ratio, not both.")
        if constant_steps is not None and constant_ratio is not None:
            raise AssertionError("Either use constant_steps or constant_ratio, not both.")
        if warmup_ratio is not None and max_steps is None:
            raise AssertionError("If warmup_ratio is provided, max_steps must also be provided.")
# It is necessary to assign all attributes *before* __init__, as class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
if constant_steps is not None:
self.constant_steps = constant_steps
elif constant_ratio is not None:
self.constant_steps = int(constant_ratio * max_steps)
else:
self.constant_steps = 0
self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)

    def get_lr(self):
"""Get learning rate at current step."""
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
)
step = self.last_epoch
# Warmup steps
        if self.warmup_steps > 0 and step <= self.warmup_steps:
return self._get_warmup_lr(step)
# Constant steps after warmup and decay
if self.constant_steps > 0 and (self.warmup_steps + self.decay_steps) < step <= self.max_steps:
return self._get_constant_lr(step)
# Min lr after max steps of updates
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_warmup_lr(self, step):
"""Get learning rate at warmup stage."""
lr_val = (step + 1) / (self.warmup_steps + 1)
return [initial_lr * lr_val for initial_lr in self.base_lrs]
def _get_constant_lr(self, step):
"""Get learning rate at constant stage."""
return [self.min_lr for _ in self.base_lrs]
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
def _sqrt_annealing(initial_lr, step, max_steps, min_lr):
"""Anneal learning rate by sqrt."""
mult = ((max_steps - step) / max_steps) ** 0.5
out_lr = initial_lr * mult
out_lr = max(out_lr, min_lr)
return out_lr
def _square_annealing(initial_lr, step, max_steps, min_lr):
"""Anneal learning rate by square."""
mult = ((max_steps - step) / max_steps) ** 2
out_lr = initial_lr * mult
out_lr = max(out_lr, min_lr)
return out_lr
def _cosine_annealing(initial_lr, step, max_steps, min_lr):
"""Anneal learning rate by cosine."""
mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
return (initial_lr - min_lr) * mult + min_lr
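
# Worked example (illustrative): halfway through the schedule the cosine factor
# is 0.5 * (1 + cos(pi / 2)) = 0.5, so the lr sits midway between initial_lr
# and min_lr:
#
#     _cosine_annealing(initial_lr=0.1, step=50, max_steps=100, min_lr=0.0)  # ~0.05
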
def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, decay_steps, min_lr):
"""Anneal learning rate by linear warmup and cosine annealing."""
    if max_lr <= min_lr:
        raise AssertionError("max_lr must be greater than min_lr")
# Use linear warmup for the initial part.
if warmup_steps > 0 and step <= warmup_steps:
return max_lr * float(step) / float(warmup_steps)
# For any steps larger than `decay_steps`, use `min_lr`.
if step > warmup_steps + decay_steps:
return min_lr
# If we are done with the warmup period, use the decay style.
num_steps_ = step - warmup_steps
decay_steps_ = decay_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
    if decay_ratio < 0.0:
        raise AssertionError("decay_ratio must be non-negative")
    if decay_ratio > 1.0:
        raise AssertionError("decay_ratio must not exceed 1.0")
delta_lr = max_lr - min_lr
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
return min_lr + coeff * delta_lr
def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
"""Polynomial decay of learning rate."""
if cycle:
multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
decay_steps *= multiplier
else:
step = min(step, decay_steps)
p = step / decay_steps
lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
lr += min_lr
return lr
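
# Worked example (illustrative): with power=1.0 the decay is linear, so halfway
# through decay_steps the lr is halfway between initial_lr and min_lr:
#
#     _poly_decay(0.1, step=50, decay_steps=100, power=1.0, min_lr=0.0, cycle=False)
#     # -> (0.1 - 0.0) * (1 - 0.5) ** 1.0 + 0.0 = 0.05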


class SquareAnnealing(WarmupPolicy):
"""Anneal learning rate by square."""
def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs):
super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
def _get_lr(self, step):
"""Get learning rate at current step."""
return [
_square_annealing(
initial_lr=initial_lr,
step=step - self.warmup_steps,
max_steps=self.max_steps - self.warmup_steps,
min_lr=self.min_lr,
)
for initial_lr in self.base_lrs
]


class SquareRootAnnealing(WarmupPolicy):
"""Anneal learning rate by square root."""
def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
def _get_lr(self, step):
"""Get learning rate at current step."""
return [
_sqrt_annealing(
initial_lr=initial_lr,
step=step,
max_steps=self.max_steps,
min_lr=self.min_lr,
)
for initial_lr in self.base_lrs
]


class CosineAnnealing(WarmupAnnealHoldPolicy):
"""Anneal learning rate by cosine."""
def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
def _get_lr(self, step):
"""Get learning rate at current step."""
for initial_lr in self.base_lrs:
if initial_lr < self.min_lr:
raise ValueError(
f"{self} received an initial learning rate that was lower than the minimum learning rate."
)
return (
[
_cosine_annealing(
initial_lr=initial_lr,
step=step - self.warmup_steps,
max_steps=self.max_steps - self.warmup_steps,
min_lr=self.min_lr,
)
for initial_lr in self.base_lrs
]
if self.constant_steps is None or self.constant_steps == 0
else self._get_linear_warmup_with_cosine_annealing_lr(step)
)
def _get_warmup_lr(self, step):
"""Get the warmup learning rate for the given step."""
if self.constant_steps is None or self.constant_steps == 0:
return super()._get_warmup_lr(step)
# Use linear warmup for the initial part.
return self._get_linear_warmup_with_cosine_annealing_lr(step)
def _get_constant_lr(self, step):
"""Only called when constant_steps is not None and not 0."""
return self._get_linear_warmup_with_cosine_annealing_lr(step)
def _get_linear_warmup_with_cosine_annealing_lr(self, step):
"""Cosine Schedule, slightly different warmup schedule + constant LR at the end."""
return [
_linear_warmup_with_cosine_annealing(
max_lr=self.base_lrs[0],
warmup_steps=self.warmup_steps,
step=step,
decay_steps=self.decay_steps,
min_lr=self.min_lr,
)
for _ in self.base_lrs
]
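
# Illustrative usage sketch (values arbitrary): warmup for 10 steps, then cosine
# decay towards min_lr by step 100. Passing constant_steps/constant_ratio instead
# switches to the linear-warmup-with-cosine-annealing variant implemented above.
#
#     import torch
#     opt = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
#     sched = CosineAnnealing(opt, warmup_steps=10, max_steps=100, min_lr=1e-5)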


class NoamAnnealing(_LRScheduler):
"""Noam learning rate annealing."""
def __init__(
self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
):
self._normalize = d_model ** (-0.5)
if warmup_steps is not None and warmup_ratio is not None:
raise AssertionError("Either use particular number of step or ratio")
if warmup_ratio is not None and max_steps is None:
raise AssertionError("If there is a ratio, there should be a total steps")
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)

    def get_lr(self):
"""Get learning rate at current step."""
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
)
step = max(1, self.last_epoch)
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
for initial_lr in self.base_lrs:
if initial_lr < self.min_lr:
raise ValueError(
f"{self} received an initial learning rate that was lower than the minimum learning rate."
)
return [self._noam_annealing(initial_lr=initial_lr, step=step) for initial_lr in self.base_lrs]
def _noam_annealing(self, initial_lr, step):
"""Noam learning rate annealing."""
mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5)))
out_lr = initial_lr * mult
if step > self.warmup_steps:
out_lr = max(out_lr, self.min_lr)
return out_lr
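
# Worked property (illustrative): the two terms inside min() intersect at
# step == warmup_steps, so the lr peaks there at
# initial_lr * d_model ** -0.5 * warmup_steps ** -0.5, growing proportionally to
# step before the peak and decaying as step ** -0.5 after it.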


class WarmupAnnealing(WarmupPolicy):
"""Warmup learning rate annealing."""
def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
def _get_lr(self, step):
"""Get learning rate at current step."""
delta_lr = self.base_lrs[0] - self.min_lr
mult = (step - self.warmup_steps) / (self.max_steps - self.warmup_steps)
return [self.min_lr + (1 - mult) * delta_lr for _ in self.base_lrs]


class InverseSquareRootAnnealing(WarmupPolicy):
"""Inverse square root learning rate annealing."""
def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)
def _get_lr(self, step):
"""Get learning rate at current step."""
denom = ((step + 1) / (self.warmup_steps + 1)) ** 0.5
return [initial_lr / denom for initial_lr in self.base_lrs]


class T5InverseSquareRootAnnealing(SquareRootConstantPolicy):
"""Inverse square root learning rate annealing."""
def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)
def _get_lr(self, step):
"""Get learning rate at current step."""
return [1 / (step**0.5) for _ in self.base_lrs]


class PolynomialDecayAnnealing(WarmupPolicy):
"""Polynomial decay learning rate annealing."""
def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs):
self.power = power
self.cycle = cycle
super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
def _get_lr(self, step):
"""Get learning rate at current step."""
return [
_poly_decay(
initial_lr,
step=step - self.warmup_steps,
decay_steps=self.max_steps - self.warmup_steps,
power=self.power,
min_lr=self.min_lr,
cycle=self.cycle,
)
for initial_lr in self.base_lrs
]


class PolynomialHoldDecayAnnealing(WarmupHoldPolicy):
"""Polynomial decay learning rate annealing."""
def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs):
self.power = power
self.cycle = cycle
super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
def _get_lr(self, step):
"""Get learning rate at current step."""
return [
_poly_decay(
initial_lr,
step=step - self.hold_steps,
decay_steps=self.max_steps - max(self.warmup_steps, self.hold_steps),
power=self.power,
min_lr=self.min_lr,
cycle=self.cycle,
)
for initial_lr in self.base_lrs
]


def register_scheduler(name: str, scheduler: _LRScheduler, scheduler_params: SchedulerParams):
"""
Checks if the scheduler name exists in the registry, and if it doesn't, adds it.
This allows custom schedulers to be added and called by name during instantiation.
Parameters
----------
    name: Name of the scheduler. Will be used as key to retrieve the scheduler.
scheduler: Scheduler class (inherits from _LRScheduler)
scheduler_params: The parameters as a dataclass of the scheduler
"""
if name in AVAILABLE_SCHEDULERS:
raise ValueError(f"Cannot override pre-existing schedulers. Conflicting scheduler name = {name}")
AVAILABLE_SCHEDULERS[name] = scheduler
sched_name = f"{scheduler.__name__}_params"
register_scheduler_params(name=sched_name, scheduler_params=scheduler_params)
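
# Illustrative registration sketch (the names below are hypothetical, not part of
# this module): a custom policy and its params dataclass can be registered and
# later resolved by name through get_scheduler().
#
#     @dataclasses.dataclass
#     class MyPolicyParams(SchedulerParams):
#         max_steps: int = 0
#
#     class MyPolicy(WarmupPolicy):
#         pass
#
#     register_scheduler(name="MyPolicy", scheduler=MyPolicy, scheduler_params=MyPolicyParams)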


def get_scheduler(name: str, **kwargs: Optional[Dict[str, Any]]) -> _LRScheduler:
"""
Convenience method to obtain an _LRScheduler class and partially instantiate it with optimizer kwargs.
Parameters
----------
name: Name of the scheduler in the registry.
kwargs: Optional kwargs of the scheduler used during instantiation.
Returns
-------
A partially instantiated _LRScheduler
"""
    if name not in AVAILABLE_SCHEDULERS:
        raise ValueError(
            f"Cannot resolve scheduler '{name}'. Available schedulers are: {AVAILABLE_SCHEDULERS.keys()}"
        )
scheduler_cls = AVAILABLE_SCHEDULERS[name]
return partial(scheduler_cls, **kwargs)
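
# Illustrative usage sketch (values arbitrary): the returned partial only needs
# the optimizer to complete instantiation.
#
#     sched_fn = get_scheduler("CosineAnnealing", max_steps=1000, min_lr=1e-5)
#     sched = sched_fn(opt)  # `opt` is an existing torch optimizer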


def prepare_lr_scheduler(
    optimizer: optim.Optimizer,
    scheduler_config: Union[Dict[str, Any], DictConfig, None],
    train_dataloader: Optional[dataloader.DataLoader] = None,
) -> Optional[Dict[str, Any]]:
    """
    Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema:
    optim:
      name: <name of optimizer>
      lr: <maximal learning rate>
      # <additional optimizer arguments>
      args:
        name: auto  # special keyword, resolves to correct optimizer config for given optimizer name
        # cls: mridc.core.conf.optimizers.NovogradParams  # explicit instantiation by class path
        params:  # optional override parameters for the optimizer config
          betas: [0.8, 0.5]
          weight_decay: 0.001
      sched:
        name: <name of scheduler>
        iters_per_batch: null  # computed at runtime; mandatory to have
        max_steps: null  # computed at runtime or explicitly set here; mandatory to have
        # pytorch lightning args <mandatory>
        monitor: val_loss
        reduce_on_plateau: false
        # <scheduler config override>
        args:
          name: auto  # special keyword, resolves to correct scheduler config for given scheduler name
          # cls: mridc.core.conf.schedulers.CosineAnnealingParams  # explicit instantiation by class path
          params:  # optional override parameters for the scheduler config
            warmup_steps: null
            warmup_ratio: null
            min_lr: 0.0
            last_epoch: -1
    Parameters
    ----------
    optimizer: The optimizer to attach the scheduler to.
    scheduler_config: The scheduler config, following the `sched` schema above.
    train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined instead of "max_steps". \
        Used to compute effective "max_steps".
    Returns
    -------
    A dictionary containing the LR Scheduler implementation if the config was successfully parsed, along with other \
    parameters required by Pytorch Lightning, otherwise None.
    """
if scheduler_config is not None:
scheduler_config = maybe_update_config_version(scheduler_config)
# Build nested dictionary for convenience out of structured objects
if isinstance(scheduler_config, DictConfig):
scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True)
elif dataclasses.is_dataclass(scheduler_config):
# Recursively transform data classes to basic dictionaries
scheduler_config = OmegaConf.create(scheduler_config)
scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True)
# Test to see if config follows above schema
add_max_args_flag = True
interval = "step"
if scheduler_config is not None:
if "args" in scheduler_config:
scheduler_args = scheduler_config.pop("args")
else:
scheduler_args = copy.deepcopy(scheduler_config)
# Remove extra parameters from scheduler_args nest
# Assume all other parameters are to be passed into scheduler constructor
if "name" in scheduler_args and scheduler_args["name"] == "ReduceLROnPlateau":
add_max_args_flag = False
interval = "epoch"
scheduler_args.pop("name", None)
scheduler_args.pop("t_max_epochs", None)
scheduler_args.pop("t_accumulate_grad_batches", None)
scheduler_args.pop("t_limit_train_batches", None)
scheduler_args.pop("t_num_workers", None)
scheduler_args.pop("monitor", None)
scheduler_args.pop("reduce_on_plateau", None)
else:
# Return gracefully in case `sched` was not supplied; inform user
logging.info("Scheduler not initialized as no `sched` config supplied to setup_optimizer()")
return None
# Try instantiation of scheduler params from config class path
if "_target_" in scheduler_args:
scheduler_args_cfg = OmegaConf.create(scheduler_args)
scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg)
scheduler_args = vars(scheduler_conf)
# Get name of the scheduler
scheduler_name = scheduler_conf.__class__.__name__
if "Params" in scheduler_name:
scheduler_name = scheduler_name.replace("Params", "")
else:
# Class path instantiation failed; try resolving "name" component
# Get name of the scheduler
if "name" in scheduler_config:
scheduler_name = scheduler_config["name"]
else:
logging.warning(
"Could not resolve classpath for Scheduler Config, and `name` "
"was not provided either. \n"
"Scheduler cannot be instantiated !"
)
return None
# If class path was not provided, perhaps `name` is provided for resolution
if "name" in scheduler_args:
        # If `auto` is passed as the name, look up the scheduler name
        # and resolve its parameter config
if scheduler_args["name"] == "auto":
scheduler_params_name = f"{scheduler_name}Params"
else:
scheduler_params_name = scheduler_args["name"]
# Get override arguments provided in the config yaml file / Dict Config
scheduler_params_override = scheduler_args.get("params", {})
# If params is itself a dict config object provided explicitly in Dict Config
# Resolve to dictionary for convenience
if isinstance(scheduler_params_override, DictConfig):
scheduler_params_override = OmegaConf.to_container(scheduler_params_override, resolve=True)
        # Get and instantiate the Config dataclass for this scheduler
        scheduler_params_cls = get_scheduler_config(scheduler_params_name, **scheduler_params_override)
        scheduler_params = scheduler_params_cls()  # instantiate the parameters object
        scheduler_args = vars(scheduler_params)  # extract just the dictionary from the Config object
# Extract value to monitor in losses, if provided.
if "monitor" in scheduler_config:
monitor = scheduler_config.get("monitor")
else:
# Default to train loss
monitor = "loss"
# Store exact max_steps if it is provided
if "max_steps" in scheduler_config and scheduler_config["max_steps"] is not None:
max_steps = scheduler_config["max_steps"]
elif "t_max_epochs" in scheduler_config:
# Compute effective max_steps if t_max_epochs is provided
if train_dataloader is None:
logging.warning(
"As `t_max_epochs` is provided/computed, it is required to pass the train dataloader in order\n"
"to compute effective maximum number of steps.\n"
"Scheduler will not be instantiated !"
)
return None
        # Warn and return gracefully if `t_max_epochs` itself is None
        if scheduler_config.get("t_max_epochs", None) is None:
            logging.warning(
                "`t_max_epochs` cannot be None when `max_steps` is not provided.\n"
                "This can occur when `train dataloader` is not available to correctly "
                "prepare the scheduler.\n"
                "Scheduler will not be instantiated !"
            )
            return None
# Get iters_per_batch
max_epochs = scheduler_config.get("t_max_epochs")
accumulate_grad_batches = scheduler_config.get("t_accumulate_grad_batches")
limit_train_batches = scheduler_config.get("t_limit_train_batches")
num_workers = scheduler_config.get("t_num_workers")
# Compute effective num max_steps
num_samples = len(train_dataloader.dataset) # type: ignore
# we may need to override ModelPT setup_optimization
if train_dataloader.batch_size is not None:
batch_size = train_dataloader.batch_size
elif hasattr(train_dataloader, "batch_sampler") and train_dataloader.batch_sampler is not None:
if train_dataloader.batch_sampler.micro_batch_size is not None:
batch_size = train_dataloader.batch_sampler.micro_batch_size
else:
raise ValueError(f"Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}")
else:
raise ValueError(f"Could not find batch_size from train_dataloader: {train_dataloader}")
drop_last = train_dataloader.drop_last
max_steps = compute_max_steps(
max_epochs=max_epochs,
accumulate_grad_batches=accumulate_grad_batches,
limit_train_batches=limit_train_batches,
num_workers=num_workers,
num_samples=num_samples,
batch_size=batch_size,
drop_last=drop_last,
)
else:
logging.warning(
"Neither `max_steps` nor `iters_per_batch` were provided to `optim.sched`, "
"cannot compute effective `max_steps` !\n"
"Scheduler will not be instantiated !"
)
return None
# Inject max_steps (effective or provided) into the scheduler config
if add_max_args_flag and scheduler_config.get("name", "") != "ExponentialLR":
scheduler_args["max_steps"] = max_steps
# Get the scheduler class from the config
scheduler_cls = get_scheduler(scheduler_name, **scheduler_args)
# Instantiate the LR schedule
schedule = scheduler_cls(optimizer, **scheduler_args)
logging.info(
'Scheduler "%s" \nwill be used during training (effective maximum steps = %d) - \nParameters : \n(%s)',
str(schedule),
max_steps,
OmegaConf.to_yaml(OmegaConf.create(scheduler_args)),
)
# Wrap the schedule in PTL arguments to perform stepwise computation
# Rather than epoch level computation
reduce_lr_on_plateau = isinstance(schedule, optim.lr_scheduler.ReduceLROnPlateau)
return {
"scheduler": schedule,
"interval": interval,
"frequency": 1,
"monitor": monitor,
"reduce_on_plateau": reduce_lr_on_plateau,
}
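
# Illustrative usage sketch (values arbitrary): a minimal dict config following
# the schema documented above; the returned dict plugs into PyTorch Lightning's
# `configure_optimizers` as the `lr_scheduler` entry.
#
#     sched_cfg = {"name": "CosineAnnealing", "max_steps": 1000, "min_lr": 1e-5,
#                  "monitor": "val_loss"}
#     lr_dict = prepare_lr_scheduler(optimizer=opt, scheduler_config=sched_cfg)
#     # -> {"scheduler": <CosineAnnealing>, "interval": "step", "frequency": 1,
#     #     "monitor": "val_loss", "reduce_on_plateau": False}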


def compute_max_steps(
max_epochs, accumulate_grad_batches, limit_train_batches, num_workers, num_samples, batch_size, drop_last
):
"""Compute effective max_steps from the provided parameters."""
_round = math.floor if drop_last else math.ceil
sampler_num_samples = math.ceil(num_samples / max(1, num_workers))
if drop_last and num_workers > 1:
logging.warning(
"Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released"
)
steps_per_epoch = _round(sampler_num_samples / batch_size)
if isinstance(limit_train_batches, int) or limit_train_batches == 0.0:
steps_per_epoch = min(steps_per_epoch, int(limit_train_batches))
elif steps_per_epoch != float("inf"):
# limit_train_batches is a percentage of batches per epoch
steps_per_epoch = int(steps_per_epoch * limit_train_batches)
return math.ceil(steps_per_epoch / accumulate_grad_batches) * max_epochs
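
# Worked example (illustrative; values arbitrary): 1000 samples, batch size 32,
# drop_last=True -> floor(1000 / 32) = 31 steps per epoch; with
# limit_train_batches=1.0 and accumulate_grad_batches=1 over 10 epochs:
#
#     compute_max_steps(max_epochs=10, accumulate_grad_batches=1,
#                       limit_train_batches=1.0, num_workers=1, num_samples=1000,
#                       batch_size=32, drop_last=True)  # -> 310
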
AVAILABLE_SCHEDULERS = {
"WarmupPolicy": WarmupPolicy,
"WarmupHoldPolicy": WarmupHoldPolicy,
"SquareAnnealing": SquareAnnealing,
"CosineAnnealing": CosineAnnealing,
"NoamAnnealing": NoamAnnealing,
"WarmupAnnealing": WarmupAnnealing,
"InverseSquareRootAnnealing": InverseSquareRootAnnealing,
"T5InverseSquareRootAnnealing": T5InverseSquareRootAnnealing,
"SquareRootAnnealing": SquareRootAnnealing,
"PolynomialDecayAnnealing": PolynomialDecayAnnealing,
"PolynomialHoldDecayAnnealing": PolynomialHoldDecayAnnealing,
"StepLR": pt_scheduler.StepLR,
"ExponentialLR": pt_scheduler.ExponentialLR,
"ReduceLROnPlateau": pt_scheduler.ReduceLROnPlateau,
"CyclicLR": pt_scheduler.CyclicLR,
}