nfnets package

Submodules

nfnets.base module

class nfnets.base.WSConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')[source]

Bases: torch.nn.modules.conv.Conv2d

Reference: https://github.com/deepmind/deepmind-research/blob/master/nfnets/base.py#L121

forward(input, eps=0.0001)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

standardize_weight(eps)[source]
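
The snippet below is a minimal usage sketch: WSConv2d accepts the same constructor arguments as nn.Conv2d, so it can be used as a drop-in replacement. The input shape is a hypothetical example.

>>> import torch
>>> from nfnets.base import WSConv2d
>>> conv = WSConv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
>>> x = torch.randn(1, 3, 32, 32)  # hypothetical NCHW input
>>> out = conv(x)                  # weights are standardized inside forward()
>>> out.shape
torch.Size([1, 16, 32, 32])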

nfnets.sgd_agc module

class nfnets.sgd_agc.SGD_AGC(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, clipping=0.01, eps=0.001)[source]

Bases: torch.optim.optimizer.Optimizer

Implements stochastic gradient descent (optionally with momentum).

Nesterov momentum is based on the formula from "On the importance of initialization and momentum in deep learning". AGC (Adaptive Gradient Clipping) is taken from the NFNets paper: https://arxiv.org/abs/2102.06171.pdf.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – learning rate

  • momentum (float, optional) – momentum factor (default: 0)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • dampening (float, optional) – dampening for momentum (default: 0)

  • nesterov (bool, optional) – enables Nesterov momentum (default: False)

  • clipping (float, optional) – clipping threshold for adaptive gradient clipping (default: 0.01)

  • eps (float, optional) – small constant used in AGC to avoid clipping zero-initialized parameters (default: 1e-3)

Example

>>> from nfnets.sgd_agc import SGD_AGC
>>> optimizer = SGD_AGC(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()

Note

The implementation has been adapted from the PyTorch framework and the official NFNets paper. The implementation of SGD with Momentum/Nesterov subtly differs from Sutskever et al. and implementations in some other frameworks.

Considering the specific case of Momentum, the update can be written as

\[\begin{split}\begin{aligned} v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, \end{aligned}\end{split}\]

where \(p\), \(g\), \(v\) and \(\mu\) denote the parameters, gradient, velocity, and momentum respectively.

This is in contrast to Sutskever et al. and other frameworks, which employ an update of the form

\[\begin{split}\begin{aligned} v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ p_{t+1} & = p_{t} - v_{t+1}. \end{aligned}\end{split}\]

The Nesterov version is analogously modified.
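
For concreteness, the sketch below writes out the PyTorch-style update from the first formula for a single parameter tensor. It is illustrative only and not the optimizer's internal code.

>>> import torch
>>> lr, mu = 0.1, 0.9
>>> p = torch.tensor([1.0, 2.0])   # parameters p_t
>>> g = torch.tensor([0.5, -0.5])  # gradient g_{t+1}
>>> v = torch.zeros_like(p)        # velocity v_t
>>> v = mu * v + g                 # v_{t+1} = mu * v_t + g_{t+1}
>>> p = p - lr * v                 # p_{t+1} = p_t - lr * v_{t+1}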

step(closure=None)[source]

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.

nfnets.sgd_agc.unitwise_norm(x: torch.Tensor)[source]
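
The function above carries no docstring here. The sketch below illustrates the idea of a unit-wise norm as used by AGC: norms are taken per output unit (per row of a linear weight, per output channel of a conv weight) rather than over the whole tensor. The dimension handling is an assumption for illustration, not a copy of this function's implementation.

>>> import torch
>>> def unitwise_norm_sketch(x: torch.Tensor) -> torch.Tensor:
...     # Scalars, biases and gains: norm of the whole tensor.
...     if x.ndim <= 1:
...         return x.norm()
...     # Weight matrices / conv kernels: one norm per output unit (dim 0).
...     dims = tuple(range(1, x.ndim))
...     return x.norm(dim=dims, keepdim=True)
>>> unitwise_norm_sketch(torch.ones(8, 3, 3, 3)).shape
torch.Size([8, 1, 1, 1])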

nfnets.utils module

nfnets.utils.replace_conv(module: torch.nn.modules.module.Module)[source]

Recursively replaces every convolution with WSConv2d.

Usage: replace_conv(model) (in-place replacement)

Parameters

module (nn.Module) – the target model whose convolutions must be replaced.
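
A minimal end-to-end sketch tying the pieces together (the torchvision model is just an example; any nn.Module containing Conv2d layers works):

>>> import torch
>>> from torchvision.models import resnet18
>>> from nfnets.utils import replace_conv
>>> from nfnets.sgd_agc import SGD_AGC
>>> model = resnet18()
>>> replace_conv(model)  # every nn.Conv2d is now a WSConv2d
>>> optimizer = SGD_AGC(model.parameters(), lr=0.1, momentum=0.9)
>>> x, target = torch.randn(2, 3, 224, 224), torch.randint(0, 1000, (2,))
>>> loss = torch.nn.functional.cross_entropy(model(x), target)
>>> loss.backward()
>>> optimizer.step()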