Source code for nfnets.base

import torch
from torch import nn


class WSConv2d(nn.Conv2d):
    """2D convolution with Scaled Weight Standardization.

    The raw ``weight`` parameter is standardized per output channel (zero
    mean, variance ``1 / fan_in``) and multiplied by a learnable per-channel
    ``gain`` on every forward pass. The stored ``weight`` itself is never
    modified by the forward pass.

    Reference:
        https://github.com/deepmind/deepmind-research/blob/master/nfnets/base.py#L121

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolving kernel.
        stride: Stride of the convolution.
        padding: Padding added to the input.
        dilation: Spacing between kernel elements.
        groups: Number of blocked connections from input to output channels.
        bias: If ``True``, adds a learnable bias to the output.
        padding_mode: Padding mode forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=bias, padding_mode=padding_mode)
        nn.init.kaiming_normal_(self.weight)
        # Registered as a Parameter so it is trained, saved in the
        # state_dict and moved together with the module by .to()/.cuda().
        self.gain = nn.Parameter(torch.ones(self.weight.size(0)))

    def standardize_weight(self, eps):
        """Return the standardized, gain-scaled convolution weight.

        Args:
            eps: Lower bound applied to ``var * fan_in`` before the
                reciprocal square root, guarding against division by zero.

        Returns:
            A tensor with the same shape as ``self.weight``. It is computed
            differentiably from ``self.weight`` and ``self.gain``.
        """
        # Population variance (as in the reference implementation); the
        # unbiased estimator is NaN when a filter has a single element.
        var, mean = torch.var_mean(self.weight, dim=(1, 2, 3),
                                   unbiased=False, keepdim=True)
        # Fan-in is the number of weights feeding one output channel, i.e.
        # every weight axis except the leading output-channel axis.
        fan_in = self.weight.numel() // self.weight.size(0)
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=eps))
        scale = scale * self.gain.view_as(var)
        return (self.weight - mean) * scale

    def forward(self, input, eps=1e-4):
        """Convolve ``input`` with the standardized weight.

        Args:
            input: Input tensor accepted by ``nn.Conv2d``.
            eps: Numerical-stability floor passed to ``standardize_weight``.

        Returns:
            The convolution output.
        """
        # Use the standardized weight functionally instead of writing it
        # back into ``self.weight``: the raw parameter is preserved and
        # gradients flow through the standardization to weight and gain.
        weight = self.standardize_weight(eps)
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly, then the
            # convolution itself runs without additional padding.
            padded = nn.functional.pad(
                input, self._reversed_padding_repeated_twice,
                mode=self.padding_mode)
            return nn.functional.conv2d(padded, weight, self.bias,
                                        self.stride, (0, 0), self.dilation,
                                        self.groups)
        return nn.functional.conv2d(input, weight, self.bias, self.stride,
                                    self.padding, self.dilation, self.groups)