Source code for super_gradients.training.models.regnet


"""
RegNet - from the paper: Designing Network Design Spaces - https://arxiv.org/pdf/2003.13678.pdf
Implementation of the paradigm described in the paper, published by Facebook AI Research (FAIR)
@author: Signatrix GmbH
Code taken from: https://github.com/signatrix/regnet - MIT Licence
"""
import numpy as np
import torch.nn as nn
from math import sqrt
from super_gradients.training.models.sg_module import SgModule
from super_gradients.training.utils.regularization_utils import DropPath
from super_gradients.training.utils.utils import get_param





class Stem(nn.Module):
    # From figure 3
    def __init__(self, in_channels, out_channels):
        super(Stem, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.rl = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.rl(x)
        return x


class XBlock(nn.Module):
    # From figure 4
    def __init__(self, in_channels, out_channels, bottleneck_ratio, group_width, stride, se_ratio=None,
                 droppath_prob=0.):
        super(XBlock, self).__init__()
        inter_channels = int(out_channels // bottleneck_ratio)
        groups = int(inter_channels // group_width)

        # 1x1 conv that maps to the bottleneck width
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU()
        )
        # 3x3 grouped conv; this is where the stride (spatial down-sampling) is applied
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, kernel_size=3, stride=stride, groups=groups,
                      padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU()
        )
        # Optional squeeze-and-excitation (the "Y" in RegNetY)
        if se_ratio is not None:
            se_channels = in_channels // se_ratio
            self.se = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=1),
                nn.Conv2d(inter_channels, se_channels, kernel_size=1, bias=True),
                nn.ReLU(),
                nn.Conv2d(se_channels, inter_channels, kernel_size=1, bias=True),
                nn.Sigmoid(),
            )
        else:
            self.se = None
        # 1x1 conv back to the block's output width
        self.conv_block_3 = nn.Sequential(
            nn.Conv2d(inter_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        # Projection shortcut whenever the resolution or channel count changes
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = None
        self.drop_path = DropPath(drop_prob=droppath_prob)
        self.rl = nn.ReLU()

    def forward(self, x):
        x1 = self.conv_block_1(x)
        x1 = self.conv_block_2(x1)
        if self.se is not None:
            x1 = x1 * self.se(x1)
        x1 = self.conv_block_3(x1)
        if self.shortcut is not None:
            x2 = self.shortcut(x)
        else:
            x2 = x
        x1 = self.drop_path(x1)
        x = self.rl(x1 + x2)
        return x


class Stage(nn.Module):
    # From figure 3
    def __init__(self, num_blocks, in_channels, out_channels, bottleneck_ratio, group_width, stride, se_ratio,
                 droppath_prob):
        super(Stage, self).__init__()
        self.blocks = nn.Sequential()
        # Only the first block of a stage applies the stride and the channel change
        self.blocks.add_module("block_0",
                               XBlock(in_channels, out_channels, bottleneck_ratio, group_width, stride,
                                      se_ratio, droppath_prob))
        for i in range(1, num_blocks):
            self.blocks.add_module("block_{}".format(i),
                                   XBlock(out_channels, out_channels, bottleneck_ratio, group_width, 1,
                                          se_ratio, droppath_prob))

    def forward(self, x):
        x = self.blocks(x)
        return x
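

# NOTE: the Head class is referenced by AnyNetX below but was missing from this extract.
# What follows is a minimal reconstruction, assuming the "AVG_POOL + FC" head that the comment in
# AnyNetX describes (global average pooling, optional dropout, then a linear classifier); the
# original implementation may differ in detail.
class Head(nn.Module):
    # From figure 3
    def __init__(self, num_channels, num_classes, dropout_prob):
        super(Head, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.dropout = nn.Dropout(p=dropout_prob) if dropout_prob > 0 else None
        self.fc = nn.Linear(num_channels, num_classes)

    def forward(self, x):
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, channels)
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc(x)
        return x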


class AnyNetX(SgModule):
    def __init__(self, ls_num_blocks, ls_block_width, ls_bottleneck_ratio, ls_group_width, stride, num_classes,
                 se_ratio, backbone_mode, dropout_prob=0., droppath_prob=0., input_channels=3):
        super(AnyNetX, self).__init__()
        verify_correctness_of_parameters(ls_num_blocks, ls_block_width, ls_bottleneck_ratio, ls_group_width)
        self.net = nn.Sequential()
        self.backbone_mode = backbone_mode
        prev_block_width = 32
        self.net.add_module("stem", Stem(in_channels=input_channels, out_channels=prev_block_width))
        for i, (num_blocks, block_width, bottleneck_ratio, group_width) in enumerate(
                zip(ls_num_blocks, ls_block_width, ls_bottleneck_ratio, ls_group_width)):
            self.net.add_module("stage_{}".format(i),
                                Stage(num_blocks, prev_block_width, block_width, bottleneck_ratio, group_width,
                                      stride, se_ratio, droppath_prob))
            prev_block_width = block_width
        # IN BACKBONE MODE - DO NOT ADD THE HEAD (AVG_POOL + FC)
        if not self.backbone_mode:
            self.net.add_module("head", Head(ls_block_width[-1], num_classes, dropout_prob))
        self.initialize_weight()

    def initialize_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming (He) initialization based on fan-out
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(mean=0.0, std=sqrt(2.0 / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(mean=0.0, std=0.01)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.net(x)
        return x


class RegNetX(AnyNetX):
    def __init__(self, initial_width, slope, quantized_param, network_depth, bottleneck_ratio, group_width,
                 stride, arch_params, se_ratio=None, input_channels=3):
        # We need to derive block width and number of blocks from the initial parameters.
        parameterized_width = initial_width + slope * np.arange(network_depth)  # From equation 2
        parameterized_block = np.log(parameterized_width / initial_width) / np.log(quantized_param)  # From equation 3
        parameterized_block = np.round(parameterized_block)
        quantized_width = initial_width * np.power(quantized_param, parameterized_block)
        # We need to convert quantized_width to make sure that it is divisible by 8
        quantized_width = 8 * np.round(quantized_width / 8)
        ls_block_width, ls_num_blocks = np.unique(quantized_width.astype(int), return_counts=True)
        # At this point, for each stage, the above-calculated block width could be incompatible with the group
        # width due to the bottleneck ratio. Hence, we need to adjust the former.
        # Group width could be swapped for the number of groups, since their product is the block width
        ls_group_width = np.array([min(group_width, block_width // bottleneck_ratio)
                                   for block_width in ls_block_width])
        ls_block_width = np.round(ls_block_width // bottleneck_ratio / group_width) * group_width
        ls_bottleneck_ratio = [bottleneck_ratio for _ in range(len(ls_block_width))]
        # GET THE BACKBONE MODE FROM arch_params IF IT EXISTS - OTHERWISE SET IT TO FALSE
        backbone_mode = get_param(arch_params, 'backbone_mode', False)
        dropout_prob = get_param(arch_params, 'dropout_prob', 0.)
        droppath_prob = get_param(arch_params, 'droppath_prob', 0.)
        super(RegNetX, self).__init__(ls_num_blocks, ls_block_width.astype(int).tolist(), ls_bottleneck_ratio,
                                      ls_group_width.tolist(), stride, arch_params.num_classes, se_ratio,
                                      backbone_mode, dropout_prob, droppath_prob, input_channels)
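

# A minimal sketch, not part of the original module: `_example_width_schedule` is a hypothetical
# helper that reproduces the equation 2/3 arithmetic from RegNetX.__init__ in isolation, using
# RegNetY200's parameters as defaults (initial_width=24, slope=36, quantized_param=2.5,
# network_depth=13). For these values the 13 per-block widths collapse into four stages.
def _example_width_schedule(initial_width=24, slope=36, quantized_param=2.5, network_depth=13):
    """Hypothetical illustration only - mirrors the width quantization performed by RegNetX."""
    parameterized_width = initial_width + slope * np.arange(network_depth)  # equation 2
    exponents = np.round(np.log(parameterized_width / initial_width) / np.log(quantized_param))  # equation 3
    quantized_width = 8 * np.round(initial_width * np.power(quantized_param, exponents) / 8)
    # Consecutive blocks sharing a width form one stage; the counts are the per-stage depths
    ls_block_width, ls_num_blocks = np.unique(quantized_width.astype(int), return_counts=True)
    return ls_block_width, ls_num_blocks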


class RegNetY(RegNetX):
    # RegNetY = RegNetX + SE
    def __init__(self, initial_width, slope, quantized_param, network_depth, bottleneck_ratio, group_width, stride,
                 arch_params, se_ratio, input_channels=3):
        super(RegNetY, self).__init__(initial_width, slope, quantized_param, network_depth, bottleneck_ratio,
                                      group_width, stride, arch_params, se_ratio, input_channels)


def verify_correctness_of_parameters(ls_num_blocks, ls_block_width, ls_bottleneck_ratio, ls_group_width):
    """VERIFY THAT THE GIVEN PARAMETERS FIT THE SEARCH SPACE DEFINED IN THE REGNET PAPER"""
    err_message = "Parameters don't fit"
    # AnyNetXb: a single, shared bottleneck ratio across all stages
    assert len(set(ls_bottleneck_ratio)) == 1, f"{err_message} AnyNetXb"
    # AnyNetXc: a single, shared group width across all stages
    assert len(set(ls_group_width)) == 1, f"{err_message} AnyNetXc"
    # AnyNetXd: block widths must be non-decreasing from stage to stage
    assert all(i <= j for i, j in zip(ls_block_width, ls_block_width[1:])), f"{err_message} AnyNetXd"
    if len(ls_num_blocks) > 2:
        # AnyNetXe: block counts must be non-decreasing, except possibly in the last stage
        assert all(i <= j for i, j in zip(ls_num_blocks[:-2], ls_num_blocks[1:-1])), f"{err_message} AnyNetXe"
    # For each stage and each layer, the number of channels (block width / bottleneck ratio)
    # must be divisible by the group width
    for block_width, bottleneck_ratio, group_width in zip(ls_block_width, ls_bottleneck_ratio, ls_group_width):
        assert int(block_width // bottleneck_ratio) % group_width == 0, \
            f"{err_message}: block width must be divisible by group width"
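

# A hedged usage sketch, not part of the original module: the verifier passes silently for a
# consistent configuration and raises AssertionError otherwise. The stage layout below is the one
# commonly cited for RegNetY-800MF (depths [1, 3, 8, 2], widths [64, 128, 320, 768], group width 16)
# and is used here purely as an assumed example.
def _example_verify():
    verify_correctness_of_parameters(ls_num_blocks=[1, 3, 8, 2],
                                     ls_block_width=[64, 128, 320, 768],
                                     ls_bottleneck_ratio=[1, 1, 1, 1],
                                     ls_group_width=[16, 16, 16, 16])  # passes: fits AnyNetXb..Xe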


class CustomRegNet(RegNetX):
    def __init__(self, arch_params):
        """All parameters must be provided in arch_params; se_ratio and input_channels are optional"""
        super().__init__(initial_width=arch_params.initial_width,
                         slope=arch_params.slope,
                         quantized_param=arch_params.quantized_param,
                         network_depth=arch_params.network_depth,
                         bottleneck_ratio=arch_params.bottleneck_ratio,
                         group_width=arch_params.group_width,
                         stride=arch_params.stride,
                         arch_params=arch_params,
                         se_ratio=arch_params.se_ratio if hasattr(arch_params, 'se_ratio') else None,
                         input_channels=get_param(arch_params, 'input_channels', 3))


class NASRegNet(RegNetX):
    def __init__(self, arch_params):
        """All parameters are provided as a single structure list: arch_params.structure =
        [initial_width, slope, quantized_param, network_depth, bottleneck_ratio, group_width, stride, se_ratio]"""
        structure = arch_params.structure
        super().__init__(initial_width=structure[0],
                         slope=structure[1],
                         quantized_param=structure[2],
                         network_depth=structure[3],
                         bottleneck_ratio=structure[4],
                         group_width=structure[5],
                         stride=structure[6],
                         se_ratio=structure[7] if structure[7] > 0 else None,  # non-positive disables SE
                         arch_params=arch_params)


# Predefined RegNetY variants; the positional arguments are
# (initial_width, slope, quantized_param, network_depth, bottleneck_ratio, group_width, stride,
#  arch_params, se_ratio)
class RegNetY200(RegNetY):
    def __init__(self, arch_params):
        super().__init__(24, 36, 2.5, 13, 1, 8, 2, arch_params, 4)


class RegNetY400(RegNetY):
    def __init__(self, arch_params):
        super().__init__(48, 28, 2.1, 16, 1, 8, 2, arch_params, 4)


class RegNetY600(RegNetY):
    def __init__(self, arch_params):
        super().__init__(48, 33, 2.3, 15, 1, 16, 2, arch_params, 4)


class RegNetY800(RegNetY):
    def __init__(self, arch_params):
        super().__init__(56, 39, 2.4, 14, 1, 16, 2, arch_params, 4)
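

# A hedged usage sketch, not part of the original module: build a RegNetY200 classifier and run a
# dummy forward pass. `HpmStruct` is assumed to be the attribute-container that super_gradients
# uses for arch_params; any object exposing a `num_classes` attribute would work equally well.
def _example_forward_pass():
    import torch
    from super_gradients.training.utils.utils import HpmStruct  # assumed import path

    arch_params = HpmStruct(num_classes=1000)
    model = RegNetY200(arch_params)
    dummy_input = torch.randn(1, 3, 224, 224)  # a single NCHW image batch
    logits = model(dummy_input)
    return logits.shape  # expected: torch.Size([1, 1000])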