Source code for super_gradients.training.models.repvgg

'''
RepVGG PyTorch implementation. The model is trained as a VGG with residual branches,
but for inference (deployment mode) it is converted to a plain VGG-style model.
Pretrained models: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
References:
    [1] https://github.com/DingXiaoH/RepVGG
    [2] https://arxiv.org/pdf/2101.03697.pdf

Based on https://github.com/DingXiaoH/RepVGG
'''
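
# Usage sketch (an illustrative assumption, not part of the original module): build a predefined
# variant from arch_params, then fuse its branches for deployment. `HpmStruct` is assumed to be the
# arch_params container exported from super_gradients.training.utils; the exact import path may
# differ between versions.
#
#   import torch
#   from super_gradients.training.utils import HpmStruct
#
#   arch_params = HpmStruct(num_classes=1000, build_residual_branches=True)
#   model = RepVggA0(arch_params)            # training-time model with 3x3 / 1x1 / identity branches
#   model.eval()                             # fusion must happen in eval mode (uses BN running stats)
#   model.prep_model_for_conversion()        # now every block is a single 3x3 conv (plain VGG)
#   out = model(torch.randn(1, 3, 224, 224))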
from typing import Union

import torch.nn as nn
import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from super_gradients.training.models import SgModule
import torch.nn.functional as F
from super_gradients.training.utils.module_utils import fuse_repvgg_blocks_residual_branches
from super_gradients.training.utils.utils import get_param


class SEBlock(nn.Module):

    def __init__(self, input_channels, internal_neurons):
        super(SEBlock, self).__init__()
        self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1,
                              stride=1, bias=True)
        self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1,
                            stride=1, bias=True)
        self.input_channels = input_channels

    def forward(self, inputs):
        x = F.avg_pool2d(inputs, kernel_size=inputs.size(3))
        x = self.down(x)
        x = F.relu(x)
        x = self.up(x)
        x = torch.sigmoid(x)
        x = x.view(-1, self.input_channels, 1, 1)
        return inputs * x


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1, dilation=1):
    result = nn.Sequential()
    result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding,
                                        groups=groups, bias=False, dilation=dilation))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepVGGBlock(nn.Module):
    '''
    RepVGG block consists of three branches:
    3x3: a branch of a 3x3 convolution + batchnorm + relu
    1x1: a branch of a 1x1 convolution + batchnorm + relu
    no_conv_branch: a branch with only batchnorm which will only be used if input channel == output channel
    (usually in all but the first block of each stage)
    '''

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 build_residual_branches=True, use_relu=True, use_se=False):
        super(RepVGGBlock, self).__init__()

        self.groups = groups
        self.in_channels = in_channels

        assert kernel_size == 3
        assert padding == dilation

        self.nonlinearity = nn.ReLU() if use_relu else nn.Identity()
        self.se = nn.Identity() if not use_se else SEBlock(out_channels, internal_neurons=out_channels // 16)

        self.no_conv_branch = nn.BatchNorm2d(
            num_features=in_channels) if out_channels == in_channels and stride == 1 else None
        self.branch_3x3 = conv_bn(in_channels=in_channels, out_channels=out_channels, dilation=dilation,
                                  kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
        self.branch_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,
                                  padding=0, groups=groups)

        if not build_residual_branches:
            self.fuse_block_residual_branches()
        else:
            self.build_residual_branches = True

    def forward(self, inputs):
        if not self.build_residual_branches:
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.no_conv_branch is None:
            id_out = 0
        else:
            id_out = self.no_conv_branch(inputs)

        return self.nonlinearity(self.se(self.branch_3x3(inputs) + self.branch_1x1(inputs) + id_out))

    def _get_equivalent_kernel_bias(self):
        """
        Fuses the 3x3, 1x1 and identity branches into the kernel and bias of a single 3x3 conv layer.
        """
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.branch_3x3)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.branch_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.no_conv_branch)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """
        Pads the 1x1 convolution weights with zeros so they can be fused with the 3x3 conv layer.

        :param kernel1x1: weights of the 1x1 convolution
        :return: zero-padded 1x1 weights (or 0 if there is no 1x1 kernel)
        """
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        """
        Fuses the batchnorm of a branch into its conv layer. If the branch is the identity branch
        (no conv), the kernel is an identity (eye-like) 3x3 kernel.

        :param branch: a conv+bn nn.Sequential, a bare nn.BatchNorm2d, or None
        :return: the fused (kernel, bias) pair
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
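
    # For reference, the batchnorm fusion applied above: for a conv with weight W followed by
    # BatchNorm2d with learnable parameters (gamma, beta) and running statistics (mu, var),
    #     W' = W * gamma / sqrt(var + eps)          (scaled per output channel)
    #     b' = beta - mu * gamma / sqrt(var + eps)
    # so that conv(W', b') matches BN(conv(W)). The identity branch is expressed as a 3x3 conv whose
    # kernel is 1 at the center of its own channel and 0 elsewhere, which lets the 3x3, zero-padded
    # 1x1 and identity kernels (and their biases) simply be summed into one 3x3 conv.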

    def fuse_block_residual_branches(self):
        """
        Converts a RepVGG block from its training form (with branches) to deployment mode
        (VGG-like, a single fused 3x3 convolution).
        """
        if hasattr(self, "build_residual_branches") and not self.build_residual_branches:
            return
        kernel, bias = self._get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.branch_3x3.conv.in_channels,
                                     out_channels=self.branch_3x3.conv.out_channels,
                                     kernel_size=self.branch_3x3.conv.kernel_size,
                                     stride=self.branch_3x3.conv.stride,
                                     padding=self.branch_3x3.conv.padding,
                                     dilation=self.branch_3x3.conv.dilation,
                                     groups=self.branch_3x3.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('branch_3x3')
        self.__delattr__('branch_1x1')
        if hasattr(self, 'no_conv_branch'):
            self.__delattr__('no_conv_branch')
        self.build_residual_branches = False
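
    # Equivalence-check sketch (an assumption for illustration, not part of the original file):
    # after fusion the block should reproduce the multi-branch outputs up to numerical precision.
    #
    #   block = RepVGGBlock(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
    #   block.eval()                         # BN must use running stats for the fused form to match
    #   x = torch.randn(1, 64, 56, 56)
    #   y_branches = block(x)
    #   block.fuse_block_residual_branches()
    #   y_fused = block(x)
    #   assert torch.allclose(y_branches, y_fused, atol=1e-5)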


class RepVGG(SgModule):

    def __init__(self, struct, num_classes=1000, width_multiplier=None, build_residual_branches=True, use_se=False,
                 backbone_mode=False, in_channels=3):
        """
        :param struct: list containing the number of blocks per RepVGG stage
        :param num_classes: number of classes if not in backbone mode
        :param width_multiplier: list of per-stage width multipliers, or a float to use a single value for all stages
        :param build_residual_branches: whether to add residual connections or not
        :param use_se: use squeeze and excitation layers
        :param backbone_mode: if true, drop the final linear layer
        :param in_channels: input channels
        """
        super(RepVGG, self).__init__()

        if isinstance(width_multiplier, float):
            width_multiplier = [width_multiplier] * 4
        else:
            assert len(width_multiplier) == 4

        self.build_residual_branches = build_residual_branches
        self.use_se = use_se
        self.backbone_mode = backbone_mode

        self.in_planes = int(64 * width_multiplier[0])

        self.stem = RepVGGBlock(in_channels=in_channels, out_channels=self.in_planes, kernel_size=3, stride=2,
                                padding=1, build_residual_branches=build_residual_branches, use_se=self.use_se)
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(int(64 * width_multiplier[0]), struct[0], stride=2)
        self.stage2 = self._make_stage(int(128 * width_multiplier[1]), struct[1], stride=2)
        self.stage3 = self._make_stage(int(256 * width_multiplier[2]), struct[2], stride=2)
        self.stage4 = self._make_stage(int(512 * width_multiplier[3]), struct[3], stride=2)
        if not self.backbone_mode:
            self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
            self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)

        if not build_residual_branches:
            self.eval()  # fusing has to be made in eval mode. When called in init, model will be built in eval mode
            fuse_repvgg_blocks_residual_branches(self)

    def _make_stage(self, planes, struct, stride):
        strides = [stride] + [1] * (struct - 1)
        blocks = []
        for stride in strides:
            blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3, stride=stride,
                                      padding=1, groups=1, build_residual_branches=self.build_residual_branches,
                                      use_se=self.use_se))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = self.stem(x)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        if not self.backbone_mode:
            out = self.avgpool(out)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
        return out

    def prep_model_for_conversion(self, input_size: Union[tuple, list] = None, **kwargs):
        if self.build_residual_branches:
            fuse_repvgg_blocks_residual_branches(self)
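
    # Export sketch (an assumption, not part of the original file): prep_model_for_conversion is
    # intended to be called before tracing/exporting, so the exported graph contains only the fused
    # 3x3 convolutions.
    #
    #   model.eval()
    #   model.prep_model_for_conversion(input_size=(3, 224, 224))
    #   torch.onnx.export(model, torch.randn(1, 3, 224, 224), "repvgg.onnx")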

    def train(self, mode: bool = True):
        assert not mode or self.build_residual_branches, "Trying to train a model without residual branches, " \
                                                         "set arch_params.build_residual_branches to True and retrain the model"
        super(RepVGG, self).train(mode=mode)


class RepVggCustom(RepVGG):
    def __init__(self, arch_params):
        super().__init__(struct=arch_params.struct, num_classes=arch_params.num_classes,
                         width_multiplier=arch_params.width_multiplier,
                         build_residual_branches=arch_params.build_residual_branches,
                         use_se=get_param(arch_params, 'use_se', False),
                         backbone_mode=get_param(arch_params, 'backbone_mode', False),
                         in_channels=get_param(arch_params, 'in_channels', 3))


class RepVggA0(RepVggCustom):
    def __init__(self, arch_params):
        arch_params.override(struct=[2, 4, 14, 1], width_multiplier=[0.75, 0.75, 0.75, 2.5])
        super().__init__(arch_params=arch_params)


class RepVggA1(RepVggCustom):
    def __init__(self, arch_params):
        arch_params.override(struct=[2, 4, 14, 1], width_multiplier=[1, 1, 1, 2.5])
        super().__init__(arch_params=arch_params)


class RepVggA2(RepVggCustom):
    def __init__(self, arch_params):
        arch_params.override(struct=[2, 4, 14, 1], width_multiplier=[1.5, 1.5, 1.5, 2.75])
        super().__init__(arch_params=arch_params)


class RepVggB0(RepVggCustom):
    def __init__(self, arch_params):
        arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[1, 1, 1, 2.5])
        super().__init__(arch_params=arch_params)


class RepVggB1(RepVggCustom):
    def __init__(self, arch_params):
        arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[2, 2, 2, 4])
        super().__init__(arch_params=arch_params)


class RepVggB2(RepVggCustom):
    def __init__(self, arch_params):
        arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[2.5, 2.5, 2.5, 5])
        super().__init__(arch_params=arch_params)


class RepVggB3(RepVggCustom):
    def __init__(self, arch_params):
        arch_params.override(struct=[4, 6, 16, 1], width_multiplier=[3, 3, 3, 5])
        super().__init__(arch_params=arch_params)


class RepVggD2SE(RepVggCustom):
    def __init__(self, arch_params):
        arch_params.override(struct=[8, 14, 24, 1], width_multiplier=[2.5, 2.5, 2.5, 5])
        super().__init__(arch_params=arch_params)
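

# Backbone-mode sketch (an assumption, not part of the original file): the named variants above only
# fix `struct` and `width_multiplier`; with backbone_mode=True the avgpool/linear head is dropped and
# forward returns the stage-4 feature map, which is how detection or segmentation heads typically
# consume it. Note that num_classes must still be present in arch_params.
#
#   from super_gradients.training.utils import HpmStruct
#   arch_params = HpmStruct(num_classes=1000, build_residual_branches=True, backbone_mode=True)
#   backbone = RepVggA1(arch_params)
#   features = backbone(torch.randn(1, 3, 224, 224))     # 4D feature map, no classification head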