
import torch
import torch.nn as nn
import torch.nn.functional as F

from super_gradients.training.models import BasicBlock, Bottleneck, SgModule, HpmStruct

"""
paper: Deep Dual-resolution Networks for Real-time and
Accurate Semantic Segmentation of Road Scenes ( https://arxiv.org/pdf/2101.06085.pdf )
code from git repo: https://github.com/ydhongHIT/DDRNet
"""


def ConvBN(in_channels: int, out_channels: int, kernel_size: int, bias=True, stride=1, padding=0, add_relu=False):
    seq = [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias, stride=stride, padding=padding),
           nn.BatchNorm2d(out_channels)]
    if add_relu:
        seq.append(nn.ReLU(inplace=True))
    return nn.Sequential(*seq)

def _make_layer(block, in_planes, planes, num_blocks, stride=1, expansion=1):
    layers = []
    layers.append(block(in_planes, planes, stride, final_relu=num_blocks > 1, expansion=expansion))
    in_planes = planes * expansion
    if num_blocks > 1:
        for i in range(1, num_blocks):
            if i == (num_blocks - 1):
                layers.append(block(in_planes, planes, stride=1, final_relu=False, expansion=expansion))
            else:
                layers.append(block(in_planes, planes, stride=1, final_relu=True, expansion=expansion))
    return nn.Sequential(*layers)
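
# A minimal usage sketch of the ConvBN helper above (the sizes are illustrative assumptions):
#   conv_bn = ConvBN(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1, add_relu=True)
#   conv_bn(torch.randn(1, 3, 64, 64)).shape  # -> torch.Size([1, 16, 32, 32])
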
class DAPPMBranch(nn.Module):
    def __init__(self, kernel_size: int, stride: int, in_planes: int, branch_planes: int, inter_mode: str = 'bilinear'):
        """
        A DAPPM branch
        :param kernel_size: the kernel size for the average pooling
                when stride=0 this parameter is omitted and AdaptiveAvgPool2d over all the input is performed
        :param stride: stride for the average pooling
                when stride=0: an AdaptiveAvgPool2d over all the input is performed (output is 1x1)
                when stride=1: no average pooling is performed
                when stride>1: average pooling is performed (scaling the input down and up again)
        :param in_planes: width of the input
        :param branch_planes: width after the first convolution
        :param inter_mode: interpolation mode for upscaling
        """
        super().__init__()

        down_list = []
        if stride == 0:
            # when stride is 0, average pool all the input to 1x1
            down_list.append(nn.AdaptiveAvgPool2d((1, 1)))
        elif stride == 1:
            # when stride is 1, no average pooling is used
            pass
        else:
            down_list.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=stride))

        down_list.append(nn.BatchNorm2d(in_planes))
        down_list.append(nn.ReLU(inplace=True))
        down_list.append(nn.Conv2d(in_planes, branch_planes, kernel_size=1, bias=False))

        self.down_scale = nn.Sequential(*down_list)
        self.up_scale = UpscaleOnline(inter_mode)

        if stride != 1:
            self.process = nn.Sequential(
                nn.BatchNorm2d(branch_planes),
                nn.ReLU(inplace=True),
                nn.Conv2d(branch_planes, branch_planes, kernel_size=3, padding=1, bias=False),
            )
    def forward(self, x):
        """
        All branches of the DAPPM but the first one receive the output of the previous branch as a second input
        :param x: in branch 0 - the original input of the DAPPM. in other branches - a list containing the original
                  input and the output of the previous branch.
        """
        if isinstance(x, list):
            output_of_prev_branch = x[1]
            x = x[0]
        else:
            output_of_prev_branch = None

        in_width = x.shape[-1]
        in_height = x.shape[-2]
        out = self.down_scale(x)
        out = self.up_scale(out, output_height=in_height, output_width=in_width)

        if output_of_prev_branch is not None:
            out = self.process(out + output_of_prev_branch)

        return out
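
# A sketch of the branch-chaining protocol described in the docstring above
# (channel and spatial sizes are illustrative assumptions):
#   branch0 = DAPPMBranch(kernel_size=1, stride=1, in_planes=512, branch_planes=128)
#   branch1 = DAPPMBranch(kernel_size=5, stride=2, in_planes=512, branch_planes=128)
#   x = torch.randn(1, 512, 8, 8)
#   out0 = branch0(x)          # branch 0 receives only the raw DAPPM input
#   out1 = branch1([x, out0])  # later branches also receive the previous branch's output
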
class DAPPM(nn.Module):
    def __init__(self, in_planes: int, branch_planes: int, out_planes: int, kernel_sizes: list, strides: list, inter_mode: str = 'bilinear'):
        super().__init__()

        assert len(kernel_sizes) == len(strides), 'len of kernel_sizes and strides must be the same'
        self.branches = nn.ModuleList()
        for kernel_size, stride in zip(kernel_sizes, strides):
            self.branches.append(DAPPMBranch(kernel_size=kernel_size, stride=stride, in_planes=in_planes,
                                             branch_planes=branch_planes, inter_mode=inter_mode))

        self.compression = nn.Sequential(
            nn.BatchNorm2d(branch_planes * len(self.branches)),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_planes * len(self.branches), out_planes, kernel_size=1, bias=False),
        )
        self.shortcut = nn.Sequential(
            nn.BatchNorm2d(in_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
        )
    def forward(self, x):
        x_list = []
        for i, branch in enumerate(self.branches):
            if i == 0:
                x_list.append(branch(x))
            else:
                x_list.append(branch([x, x_list[i - 1]]))

        out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x)
        return out
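
# An illustrative DAPPM sketch using the default DDRNet-23 pooling configuration
# (the channel and spatial sizes are assumptions for the example):
#   spp = DAPPM(in_planes=1024, branch_planes=128, out_planes=256,
#               kernel_sizes=[1, 5, 9, 17, 0], strides=[1, 2, 4, 8, 0])
#   spp(torch.randn(1, 1024, 8, 8)).shape  # -> torch.Size([1, 256, 8, 8])
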
class SegmentHead(nn.Module):
    def __init__(self, in_planes: int, inter_planes: int, out_planes: int, scale_factor: int, inter_mode: str = 'bilinear'):
        """
        Last stage of the segmentation network.
        Reduces the number of output planes (usually to num_classes) while increasing the size by scale_factor
        :param in_planes: width of input
        :param inter_planes: width of internal conv. must be a multiple of scale_factor^2 when inter_mode=pixel_shuffle
        :param out_planes: output width
        :param scale_factor: scaling factor
        :param inter_mode: one of nearest, linear, bilinear, bicubic, trilinear, area or pixel_shuffle.
            when set to pixel_shuffle, an nn.PixelShuffle will be used for scaling
        """
        super().__init__()

        if inter_mode == 'pixel_shuffle':
            assert inter_planes % (scale_factor ** 2) == 0, 'when using pixel_shuffle, inter_planes must be a multiple of scale_factor^2'

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(inter_planes)
        self.relu = nn.ReLU(inplace=True)

        if inter_mode == 'pixel_shuffle':
            self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=1, padding=0, bias=True)
            self.upscale = nn.PixelShuffle(scale_factor)
        else:
            self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=1, padding=0, bias=True)
            self.upscale = nn.Upsample(scale_factor=scale_factor, mode=inter_mode)

        self.scale_factor = scale_factor
    def forward(self, x):
        x = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(x)))
        out = self.upscale(out)
        return out
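
# An illustrative head sketch (assumed sizes; 19 classes as in e.g. Cityscapes):
#   head = SegmentHead(in_planes=256, inter_planes=128, out_planes=19, scale_factor=8)
#   head(torch.randn(1, 256, 64, 128)).shape  # -> torch.Size([1, 19, 512, 1024])
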
class UpscaleOnline(nn.Module):
    """
    In some cases the required scale/size for the scaling is known only when the input is received.
    This class supports such cases; only the interpolation mode is set in advance.
    """
    def __init__(self, mode='bilinear'):
        super().__init__()
        self.mode = mode
    def forward(self, x, output_height: int, output_width: int):
        return F.interpolate(x, size=[output_height, output_width], mode=self.mode)
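
# A minimal usage sketch (the sizes are illustrative assumptions):
#   up = UpscaleOnline(mode='bilinear')
#   up(torch.randn(1, 64, 16, 16), output_height=128, output_width=128).shape  # -> torch.Size([1, 64, 128, 128])
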
class DDRBackBoneBase(nn.Module):
    """A base class defining functions that must be supported by DDRBackBones"""
    def validate_backbone_attributes(self):
        expected_attributes = ['stem', 'layer1', 'layer2', 'layer3', 'layer4', 'input_channels']
        for attribute in expected_attributes:
            assert hasattr(self, attribute), f"Invalid backbone - attribute '{attribute}' is missing"
    def get_backbone_output_number_of_channels(self):
        """Return a dictionary of the shapes of each output of the backbone to determine the in_channels of the
        skip and compress layers"""
        output_shapes = {}
        x = torch.randn(1, self.input_channels, 320, 320)
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        output_shapes['layer2'] = x.shape[1]
        x = self.layer3(x)
        output_shapes['layer3'] = x.shape[1]
        x = self.layer4(x)
        output_shapes['layer4'] = x.shape[1]
        return output_shapes
class BasicDDRBackBone(DDRBackBoneBase):
    def __init__(self, block: nn.Module.__class__, width: int, layers: list, input_channels: int):
        super().__init__()
        self.input_channels = input_channels
        self.stem = nn.Sequential(
            ConvBN(in_channels=input_channels, out_channels=width, kernel_size=3, stride=2, padding=1, add_relu=True),
            ConvBN(in_channels=width, out_channels=width, kernel_size=3, stride=2, padding=1, add_relu=True),
        )
        self.layer1 = _make_layer(block=block, in_planes=width, planes=width, num_blocks=layers[0])
        self.layer2 = _make_layer(block=block, in_planes=width, planes=width * 2, num_blocks=layers[1], stride=2)
        self.layer3 = _make_layer(block=block, in_planes=width * 2, planes=width * 4, num_blocks=layers[2], stride=2)
        self.layer4 = _make_layer(block=block, in_planes=width * 4, planes=width * 8, num_blocks=layers[3], stride=2)
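
# A sketch of the DDRNet-23 backbone configuration (width and layer counts taken from
# DEFAULT_DDRNET_23_PARAMS below; with BasicBlock, expansion is 1):
#   backbone = BasicDDRBackBone(block=BasicBlock, width=64, layers=[2, 2, 2, 2], input_channels=3)
#   backbone.get_backbone_output_number_of_channels()  # -> {'layer2': 128, 'layer3': 256, 'layer4': 512}
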
class RegnetDDRBackBone(DDRBackBoneBase):
    """
    Translation of Regnet to fit DDR model
    """
    def __init__(self, regnet_module: nn.Module.__class__):
        super().__init__()
        self.input_channels = regnet_module.net.stem.conv.in_channels
        self.stem = regnet_module.net.stem
        self.layer1 = regnet_module.net.stage_0
        self.layer2 = regnet_module.net.stage_1
        self.layer3 = regnet_module.net.stage_2
        self.layer4 = regnet_module.net.stage_3
class DDRNet(SgModule):
    def __init__(self, backbone: DDRBackBoneBase.__class__, additional_layers: list, upscale_module: nn.Module,
                 num_classes: int, highres_planes: int, spp_width: int, head_width: int, aux_head: bool = False,
                 ssp_inter_mode: str = 'bilinear', segmentation_inter_mode: str = 'bilinear',
                 skip_block: nn.Module.__class__ = None, layer5_block: nn.Module.__class__ = Bottleneck,
                 layer5_bottleneck_expansion: int = 2, classification_mode=False,
                 spp_kernel_sizes: list = [1, 5, 9, 17, 0], spp_strides: list = [1, 2, 4, 8, 0]):
        """
        :param backbone: the low resolution branch of DDR, expected to have specific attributes in the class
        :param additional_layers: list of num blocks for the highres stage and layer5
        :param upscale_module: upscale to use in the backbone (DAPPM and Segmentation head are using bilinear interpolation)
        :param num_classes: number of classes
        :param highres_planes: number of channels in the high resolution net
        :param spp_width: width of the branches in the SPP block
        :param head_width: width of the internal conv in the segmentation head
        :param aux_head: add a second segmentation head (fed from after compress3 + upscale).
            this head can be used during training (see paper https://arxiv.org/pdf/2101.06085.pdf for details)
        :param ssp_inter_mode: the interpolation used in the SPP block
        :param segmentation_inter_mode: the interpolation used in the segmentation head
        :param skip_block: allows specifying a different block (from 'block') for the skip layer
        :param layer5_block: type of block to use in layer5 and layer5_skip
        :param layer5_bottleneck_expansion: determines the expansion rate for Bottleneck block
        :param classification_mode: when True, a classification head replaces the segmentation heads
            (used for pre-training the backbone on Imagenet)
        :param spp_kernel_sizes: list of kernel sizes for the spp module pooling
        :param spp_strides: list of strides for the spp module pooling
        """
        super().__init__()
        self.aux_head = aux_head
        self.upscale = upscale_module
        self.ssp_inter_mode = ssp_inter_mode
        self.segmentation_inter_mode = segmentation_inter_mode
        self.relu = nn.ReLU(inplace=False)
        self.classification_mode = classification_mode

        assert not (aux_head and classification_mode), 'auxiliary head cannot be used in classification mode'
        assert isinstance(backbone, DDRBackBoneBase), 'The backbone must inherit from DDRBackBoneBase'

        self.backbone = backbone
        self.backbone.validate_backbone_attributes()
        out_chan_backbone = self.backbone.get_backbone_output_number_of_channels()

        self.compression3 = ConvBN(in_channels=out_chan_backbone['layer3'], out_channels=highres_planes, kernel_size=1, bias=False)
        self.compression4 = ConvBN(in_channels=out_chan_backbone['layer4'], out_channels=highres_planes, kernel_size=1, bias=False)

        self.down3 = ConvBN(in_channels=highres_planes, out_channels=out_chan_backbone['layer3'], kernel_size=3, stride=2, padding=1, bias=False)
        self.down4 = nn.Sequential(
            ConvBN(in_channels=highres_planes, out_channels=highres_planes * 2, kernel_size=3, stride=2, padding=1, bias=False, add_relu=True),
            ConvBN(in_channels=highres_planes * 2, out_channels=out_chan_backbone['layer4'], kernel_size=3, stride=2, padding=1, bias=False))

        self.layer3_skip = _make_layer(block=skip_block, in_planes=out_chan_backbone['layer2'], planes=highres_planes, num_blocks=additional_layers[1])
        self.layer4_skip = _make_layer(block=skip_block, in_planes=highres_planes, planes=highres_planes, num_blocks=additional_layers[2])
        self.layer5_skip = _make_layer(block=layer5_block, in_planes=highres_planes, planes=highres_planes,
                                       num_blocks=additional_layers[3], expansion=layer5_bottleneck_expansion)

        # when training the backbones on Imagenet:
        # - layer 5 has stride 1
        # - a new high_to_low_fusion is added with two 3x3 convs with stride 2 (and double the width)
        # - a classification head is placed instead of the segmentation head
        if self.classification_mode:
            self.layer5 = _make_layer(block=layer5_block, in_planes=out_chan_backbone['layer4'], planes=out_chan_backbone['layer4'],
                                      num_blocks=additional_layers[0], expansion=layer5_bottleneck_expansion)

            highres_planes_out = highres_planes * layer5_bottleneck_expansion
            self.high_to_low_fusion = nn.Sequential(
                ConvBN(in_channels=highres_planes_out, out_channels=highres_planes_out * 2,
                       kernel_size=3, stride=2, padding=1, add_relu=True),
                ConvBN(in_channels=highres_planes_out * 2, out_channels=out_chan_backbone['layer4'] * layer5_bottleneck_expansion,
                       kernel_size=3, stride=2, padding=1, add_relu=True))

            self.average_pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(in_features=out_chan_backbone['layer4'] * layer5_bottleneck_expansion, out_features=num_classes)
        else:
            self.layer5 = _make_layer(block=layer5_block, in_planes=out_chan_backbone['layer4'], planes=out_chan_backbone['layer4'],
                                      num_blocks=additional_layers[0], stride=2, expansion=layer5_bottleneck_expansion)

            self.spp = DAPPM(in_planes=out_chan_backbone['layer4'] * layer5_bottleneck_expansion, branch_planes=spp_width,
                             out_planes=highres_planes * layer5_bottleneck_expansion, inter_mode=self.ssp_inter_mode,
                             kernel_sizes=spp_kernel_sizes, strides=spp_strides)

            if self.aux_head:
                self.seghead_extra = SegmentHead(highres_planes, head_width, num_classes, 8, inter_mode=self.segmentation_inter_mode)

            self.final_layer = SegmentHead(highres_planes * layer5_bottleneck_expansion, head_width, num_classes, 8,
                                           inter_mode=self.segmentation_inter_mode)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def forward(self, x):
        width_output = x.shape[-1] // 8
        height_output = x.shape[-2] // 8

        x = self.backbone.stem(x)
        x = self.backbone.layer1(x)
        out_layer2 = self.backbone.layer2(self.relu(x))

        out_layer3 = self.backbone.layer3(self.relu(out_layer2))
        out_layer3_skip = self.layer3_skip(self.relu(out_layer2))

        x = out_layer3 + self.down3(self.relu(out_layer3_skip))
        x_skip = out_layer3_skip + self.upscale(self.compression3(self.relu(out_layer3)), height_output, width_output)

        # save for auxiliary head
        if self.aux_head:
            temp = x_skip

        out_layer4 = self.backbone.layer4(self.relu(x))
        out_layer4_skip = self.layer4_skip(self.relu(x_skip))

        x = out_layer4 + self.down4(self.relu(out_layer4_skip))
        x_skip = out_layer4_skip + self.upscale(self.compression4(self.relu(out_layer4)), height_output, width_output)

        out_layer5_skip = self.layer5_skip(self.relu(x_skip))

        if self.classification_mode:
            x_skip = self.high_to_low_fusion(self.relu(out_layer5_skip))
            x = self.layer5(self.relu(x))
            x = self.average_pool(x + x_skip)
            x = self.fc(x.squeeze())
            return x
        else:
            x = self.upscale(self.spp(self.layer5(self.relu(x))), height_output, width_output)

            x = self.final_layer(x + out_layer5_skip)

            if self.aux_head:
                x_extra = self.seghead_extra(temp)
                return x, x_extra
            else:
                return x
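
# Note on output shapes: the segmentation heads upscale by 8 from the 1/8-resolution feature
# maps, so for an HxW input the segmentation output is num_classes x H x W; with aux_head=True
# forward returns a (main, auxiliary) tuple of such maps, and in classification mode it
# returns a logits vector of length num_classes.
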
class DDRNetCustom(DDRNet):
    def __init__(self, arch_params: HpmStruct):
        """Parse arch_params and translate the parameters to build the original DDRNet architecture"""
        super().__init__(backbone=arch_params.backbone,
                         additional_layers=arch_params.additional_layers,
                         upscale_module=arch_params.upscale_module,
                         num_classes=arch_params.num_classes,
                         highres_planes=arch_params.highres_planes,
                         spp_width=arch_params.spp_planes,
                         head_width=arch_params.head_planes,
                         aux_head=arch_params.aux_head,
                         ssp_inter_mode=arch_params.ssp_inter_mode,
                         segmentation_inter_mode=arch_params.segmentation_inter_mode,
                         skip_block=arch_params.skip_block,
                         layer5_block=arch_params.layer5_block,
                         layer5_bottleneck_expansion=arch_params.layer5_bottleneck_expansion,
                         classification_mode=arch_params.classification_mode,
                         spp_kernel_sizes=arch_params.spp_kernel_sizes,
                         spp_strides=arch_params.spp_strides)
DEFAULT_DDRNET_23_PARAMS = {
    "input_channels": 3,
    "block": BasicBlock,
    "skip_block": BasicBlock,
    "layer5_block": Bottleneck,
    "layer5_bottleneck_expansion": 2,
    "layers": [2, 2, 2, 2, 1, 2, 2, 1],
    "upscale_module": UpscaleOnline(),
    "planes": 64,
    "highres_planes": 128,
    "head_planes": 128,
    "aux_head": False,
    "segmentation_inter_mode": 'bilinear',
    "classification_mode": False,
    "spp_planes": 128,
    "ssp_inter_mode": 'bilinear',
    "spp_kernel_sizes": [1, 5, 9, 17, 0],
    "spp_strides": [1, 2, 4, 8, 0],
}

DEFAULT_DDRNET_23_SLIM_PARAMS = {
    **DEFAULT_DDRNET_23_PARAMS,
    "planes": 32,
    "highres_planes": 64,
    "head_planes": 64,
}
class DDRNet23(DDRNetCustom):
    def __init__(self, arch_params: HpmStruct):
        _arch_params = HpmStruct(**DEFAULT_DDRNET_23_PARAMS)
        _arch_params.override(**arch_params.to_dict())
        # BUILD THE BACKBONE AND INSERT IT INTO _arch_params
        backbone_layers, _arch_params.additional_layers = _arch_params.layers[:4], _arch_params.layers[4:]
        _arch_params.backbone = BasicDDRBackBone(block=_arch_params.block, width=_arch_params.planes,
                                                 layers=backbone_layers, input_channels=_arch_params.input_channels)
        super().__init__(_arch_params)
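
# A sketch of building and running the model (num_classes=19 is an illustrative value,
# e.g. Cityscapes; the input size is an assumption for the example):
#   model = DDRNet23(HpmStruct(num_classes=19))
#   model(torch.randn(1, 3, 512, 1024)).shape  # -> torch.Size([1, 19, 512, 1024])
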
class DDRNet23Slim(DDRNetCustom):
    def __init__(self, arch_params: HpmStruct):
        _arch_params = HpmStruct(**DEFAULT_DDRNET_23_SLIM_PARAMS)
        _arch_params.override(**arch_params.to_dict())
        # BUILD THE BACKBONE AND INSERT IT INTO _arch_params
        backbone_layers, _arch_params.additional_layers = _arch_params.layers[:4], _arch_params.layers[4:]
        _arch_params.backbone = BasicDDRBackBone(block=_arch_params.block, width=_arch_params.planes,
                                                 layers=backbone_layers, input_channels=_arch_params.input_channels)
        super().__init__(_arch_params)
class AnyBackBoneDDRNet23(DDRNetCustom):
    def __init__(self, arch_params: HpmStruct):
        _arch_params = HpmStruct(**DEFAULT_DDRNET_23_PARAMS)
        _arch_params.override(**arch_params.to_dict())
        assert len(_arch_params.layers) == 4 or len(_arch_params.layers) == 8, \
            "The length of 'arch_params.layers' must be 4 or 8"
        # TAKE THE LAST 4 NUMBERS AS THE ADDITIONAL LAYERS SPECIFICATION
        _arch_params.additional_layers = _arch_params.layers[-4:]
        assert hasattr(_arch_params, 'backbone'), 'AnyBackBoneDDRNet23 requires having a backbone in arch_params'
        if hasattr(_arch_params, 'input_channels'):
            assert _arch_params.backbone.input_channels == _arch_params.input_channels, \
                "'input_channels' was given in arch_params with a different value than existing in the backbone"
        super().__init__(_arch_params)