nfnets package
Submodules
nfnets.base module
class nfnets.base.WSConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

Bases: torch.nn.modules.conv.Conv2d

Reference: https://github.com/deepmind/deepmind-research/blob/master/nfnets/base.py#L121
forward(input, eps=0.0001)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
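A minimal usage sketch follows, in the doctest style used elsewhere on this page. It assumes only the module path and constructor signature documented above; since WSConv2d subclasses torch.nn.modules.conv.Conv2d with the same constructor arguments, it can be used as a drop-in replacement, and the tensor shapes below are illustrative.

>>> import torch
>>> from nfnets.base import WSConv2d
>>> conv = WSConv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
>>> x = torch.randn(8, 3, 32, 32)  # (batch, channels, height, width)
>>> y = conv(x)                    # invokes forward(input, eps=0.0001)
>>> y.shape
torch.Size([8, 16, 32, 32])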
nfnets.sgd_agc module
class nfnets.sgd_agc.SGD_AGC(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, clipping=0.01, eps=0.001)

Bases: torch.optim.optimizer.Optimizer

Implements stochastic gradient descent (optionally with momentum) with adaptive gradient clipping (AGC).

Nesterov momentum is based on the formula from "On the importance of initialization and momentum in deep learning". AGC is from the NF-Nets paper: https://arxiv.org/abs/2102.06171
Parameters

params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate
momentum (float, optional) – momentum factor (default: 0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
dampening (float, optional) – dampening for momentum (default: 0)
nesterov (bool, optional) – enables Nesterov momentum (default: False)
clipping (float, optional) – adaptive gradient clipping threshold (default: 0.01)
eps (float, optional) – epsilon term used in the AGC norm computation (default: 1e-3)
Example

>>> from nfnets.sgd_agc import SGD_AGC
>>> optimizer = SGD_AGC(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
The implementation has been adapted from the PyTorch framework and the official NF-Nets paper. The implementation of SGD with Momentum/Nesterov subtly differs from Sutskever et al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
\[\begin{split}\begin{aligned} v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, \end{aligned}\end{split}\]

where \(p\), \(g\), \(v\) and \(\mu\) denote the parameters, gradient, velocity, and momentum respectively.
This is in contrast to Sutskever et al. and other frameworks, which employ an update of the form
\[\begin{split}\begin{aligned} v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ p_{t+1} & = p_{t} - v_{t+1}. \end{aligned}\end{split}\]

The Nesterov version is analogously modified.
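For illustration, the sketch below applies the adaptive gradient clipping rule from the NF-Nets paper cited above by hand, before an optimizer step. It reuses the clipping and eps names from the SGD_AGC signature and the model from the example above; it is a standalone sketch of the AGC formula, not the optimizer's exact internal implementation.

>>> import torch
>>> clipping, eps = 0.01, 1e-3                 # defaults from the signature above
>>> for p in model.parameters():               # `model` as in the example above
...     if p.grad is None or p.dim() < 2:
...         continue                           # skip biases/gains and unused params
...     dims = tuple(range(1, p.dim()))        # unit-wise norms: one per output unit
...     p_norm = p.detach().norm(dim=dims, keepdim=True).clamp(min=eps)
...     g_norm = p.grad.detach().norm(dim=dims, keepdim=True)
...     # Rescale units whose gradient norm exceeds clipping * ||p||; the small
...     # constant guards against division by zero.
...     scale = (clipping * p_norm / g_norm.clamp(min=1e-6)).clamp(max=1.0)
...     p.grad.mul_(scale)

The snippet is only meant to make the roles of the clipping and eps parameters concrete; SGD_AGC performs the clipping for you when it updates the parameters.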