"""
ShelfNet
paper: https://arxiv.org/abs/1811.11254
based on: https://github.com/juntang-zhuang/ShelfNet
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from super_gradients.training.models.sg_module import SgModule
from super_gradients.training.utils import HpmStruct
from super_gradients.training.models.resnet import BasicBlock, ResNet, Bottleneck
class FCNHead(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
inter_channels = in_channels // 4
self.fcn = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU(),
nn.Dropout2d(0.1, False),
nn.Conv2d(inter_channels, out_channels, 1))
def forward(self, x):
return self.fcn(x)
class ShelfBlock(nn.Module):
def __init__(self, in_planes: int, planes: int, stride: int = 1, dropout: float = 0.25):
"""
S-Block implementation from the ShelfNet paper
:param in_planes: input planes
:param planes: output planes
:param stride: convolution stride
:param dropout: dropout probability
"""
super().__init__()
if in_planes != planes:
self.conv0 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=True)
self.relu0 = nn.ReLU(inplace=True)
self.in_planes = in_planes
self.planes = planes
self.conv1 = nn.Conv2d(self.planes, self.planes, kernel_size=3, stride=stride, padding=1, bias=True)
self.bn1 = nn.BatchNorm2d(self.planes)
self.relu1 = nn.ReLU(inplace=True)
self.dropout = nn.Dropout2d(p=dropout)
self.bn2 = nn.BatchNorm2d(self.planes)
self.relu2 = nn.ReLU(inplace=True)
def forward(self, x):
if self.in_planes != self.planes:
x = self.conv0(x)
x = self.relu0(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.dropout(out)
out = self.conv1(out)  # CONV1 IS APPLIED A SECOND TIME ON PURPOSE - THE S-BLOCK SHARES WEIGHTS BETWEEN ITS TWO CONVOLUTIONS
out = self.bn2(out)
out = out + x
return self.relu2(out)
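# Illustrative sketch (not part of the original module): with stride=1 the S-Block keeps the spatial
# size and maps in_planes -> planes, applying conv1 twice with shared weights plus a residual add.
# `_example_s_block_shapes` is a hypothetical helper added only for clarity.
def _example_s_block_shapes():
    block = ShelfBlock(in_planes=64, planes=128, stride=1, dropout=0.25)
    x = torch.randn(2, 64, 32, 32)
    out = block(x)
    # CHANNELS GO FROM 64 TO 128, THE 32x32 SPATIAL SIZE IS PRESERVED
    assert out.shape == (2, 128, 32, 32)
    return out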
class ShelfResNetBackBone(ResNet):
"""
ShelfResNetBackBone - Inherits from the original ResNet class and overrides the forward pass to return the
intermediate feature maps, creating a backbone for the ShelfNet architecture
"""
def __init__(self, block, num_blocks, num_classes=10, width_mult=1):
super().__init__(block=block, num_blocks=num_blocks, num_classes=num_classes, width_mult=width_mult,
backbone_mode=True)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.maxpool(out)
feat4 = self.layer1(out) # 1/4
feat8 = self.layer2(feat4) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat4, feat8, feat16, feat32
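# Illustrative sketch (not part of the original module): the backbone returns four feature maps at
# strides of roughly 4, 8, 16 and 32 relative to the input (per the comments in forward above),
# which the ShelfNet decoder and ladder consume as a list. `_example_backbone_features` is a
# hypothetical helper added only for clarity.
def _example_backbone_features():
    backbone = ShelfResNetBackBone(BasicBlock, [2, 2, 2, 2], num_classes=21)
    backbone.eval()
    with torch.no_grad():
        feats = backbone(torch.randn(1, 3, 224, 224))
    # CHANNELS FOR A BasicBlock BACKBONE: 64, 128, 256, 512
    return [f.shape for f in feats]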
class ShelfResNetBackBone18(ShelfResNetBackBone):
def __init__(self, num_classes: int):
super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
class ShelfResNetBackBone34(ShelfResNetBackBone):
def __init__(self, num_classes: int):
super().__init__(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
class ShelfResNetBackBone503343(ShelfResNetBackBone):
def __init__(self, num_classes: int):
super().__init__(Bottleneck, [3, 3, 4, 3], num_classes=num_classes)
class ShelfResNetBackBone50(ShelfResNetBackBone):
def __init__(self, num_classes: int):
super().__init__(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
class ShelfResNetBackBone101(ShelfResNetBackBone):
def __init__(self, num_classes: int):
super().__init__(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
class ShelfNetModuleBase(SgModule):
"""
ShelfNetModuleBase - Base class for the different Modules of the ShelfNet Architecture
"""
def __init__(self):
super().__init__()
def forward(self, x):
raise NotImplementedError
def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
wd_params.append(module.weight)
if module.bias is not None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters())
return wd_params, nowd_params
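# Illustrative sketch (not part of the original module): get_params separates Conv/Linear weights
# (which normally receive weight decay) from biases and BatchNorm parameters (which normally do not).
# A hypothetical way to feed the split into an optimizer:
def _example_weight_decay_split(module: "ShelfNetModuleBase", lr: float = 1e-2):
    wd_params, nowd_params = module.get_params()
    return torch.optim.SGD(
        [
            {"params": wd_params, "weight_decay": 5e-4},   # DECAYED GROUP
            {"params": nowd_params, "weight_decay": 0.0},  # NON-DECAYED GROUP
        ],
        lr=lr, momentum=0.9)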
class ConvBNReLU(ShelfNetModuleBase):
def __init__(self, in_chan: int, out_chan: int, ks: int = 3, stride: int = 1, padding: int = 1):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
out_chan,
kernel_size=ks,
stride=stride,
padding=padding,
bias=False)
self.bn = nn.BatchNorm2d(out_chan)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = F.relu(x)
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if ly.bias is not None:
nn.init.constant_(ly.bias, 0)
class DecoderBase(ShelfNetModuleBase):
def __init__(self, planes: int, layers: int, kernel: int = 3, block=ShelfBlock):
super().__init__()
self.planes = planes
self.layers = layers
self.kernel = kernel
self.padding = int((kernel - 1) / 2)
self.inconv = block(planes, planes)
# CREATE MODULE FOR BOTTOM BLOCK
self.bottom = block(planes * (2 ** (layers - 1)), planes * (2 ** (layers - 1)))
# CREATE MODULE LIST FOR UP BRANCH
self.up_conv_list = nn.ModuleList()
self.up_dense_list = nn.ModuleList()
def forward(self, x):
raise NotImplementedError
class DecoderHW(DecoderBase):
"""
DecoderHW - The Decoder for the Heavy-Weight ShelfNet Architecture
"""
def __init__(self, planes, layers, block=ShelfBlock, *args, **kwargs):
super().__init__(planes=planes, layers=layers, block=block, *args, **kwargs)
for i in range(0, layers - 1):
self.up_conv_list.append(
nn.ConvTranspose2d(planes * 2 ** (layers - 1 - i), planes * 2 ** max(0, layers - i - 2), kernel_size=3,
stride=2, padding=1, output_padding=1, bias=True))
self.up_dense_list.append(block(planes * 2 ** max(0, layers - i - 2), planes * 2 ** max(0, layers - i - 2)))
def forward(self, x):
# BOTTOM BRANCH
out = self.bottom(x[-1])
bottom = out
# UP BRANCH
up_out = []
up_out.append(bottom)
for j in range(0, self.layers - 1):
out = self.up_conv_list[j](out) + x[self.layers - j - 2]
out = self.up_dense_list[j](out)
up_out.append(out)
return up_out
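# Illustrative sketch (not part of the original module): DecoderHW expects a list of `layers` feature
# maps whose channels double and whose spatial size halves at every level, and returns the up-branch
# outputs ordered from coarsest to finest. `_example_decoder_hw` is a hypothetical helper added only
# for clarity.
def _example_decoder_hw(planes: int = 64, layers: int = 4):
    decoder = DecoderHW(planes=planes, layers=layers)
    feats = [torch.randn(1, planes * 2 ** i, 64 // 2 ** i, 64 // 2 ** i) for i in range(layers)]
    up_out = decoder(feats)
    # up_out[0] IS THE BOTTOM (COARSEST) MAP, up_out[-1] THE FINEST (planes CHANNELS, FULL SIZE)
    return [f.shape for f in up_out]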
class DecoderLW(DecoderBase):
"""
DecoderLW - The Decoder for the Light-Weight ShelfNet Architecture
"""
def __init__(self, planes, layers, block=ShelfBlock, *args, **kwargs):
super().__init__(planes=planes, layers=layers, block=block, *args, **kwargs)
for i in range(0, layers - 1):
self.up_conv_list.append(
AttentionRefinementModule(planes * 2 ** (layers - 1 - i), planes * 2 ** max(0, layers - i - 2)))
self.up_dense_list.append(
ConvBNReLU(in_chan=planes * 2 ** max(0, layers - i - 2), out_chan=planes * 2 ** max(0, layers - i - 2),
ks=3, stride=1))
def forward(self, x):
# BOTTOM BRANCH
out = self.bottom(x[-1])
bottom = out
# UP BRANCH
up_out = []
up_out.append(bottom)
for j in range(0, self.layers - 1):
out = self.up_conv_list[j](out)
out_interpolate = F.interpolate(out, (out.size(2) * 2, out.size(3) * 2), mode='nearest')
out = out_interpolate + x[self.layers - j - 2]
out = self.up_dense_list[j](out)
up_out.append(out)
return up_out
class AttentionRefinementModule(nn.Module):
def __init__(self, in_chan, out_chan):
super().__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
self.bn_atten = nn.BatchNorm2d(out_chan)
self.sigmoid_atten = nn.Sigmoid()
self.init_weight()
def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = torch.mul(feat, atten)
return out
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if ly.bias is not None:
nn.init.constant_(ly.bias, 0)
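# Illustrative sketch (not part of the original module): the attention refinement module builds a
# per-channel gate from a global-average-pooled descriptor (1x1 conv + BN + sigmoid) and rescales
# the ConvBNReLU features with it, so the output keeps the input spatial size with out_chan channels.
def _example_attention_refinement():
    arm = AttentionRefinementModule(in_chan=128, out_chan=64)
    arm.eval()  # EVAL MODE SO BATCHNORM OVER THE 1x1 POOLED DESCRIPTOR DOES NOT NEED BATCH STATISTICS
    out = arm(torch.randn(1, 128, 16, 16))
    assert out.shape == (1, 64, 16, 16)
    return out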
class LadderBlockBase(ShelfNetModuleBase):
def __init__(self, planes: int, layers: int, kernel: int = 3, block=ShelfBlock):
super().__init__()
self.planes = planes
self.layers = layers
self.kernel = kernel
self.padding = int((kernel - 1) / 2)
self.inconv = block(planes, planes)
# CREATE MODULE LIST FOR DOWN BRANCH
self.down_module_list = nn.ModuleList()
for i in range(0, layers - 1):
self.down_module_list.append(block(planes * (2 ** i), planes * (2 ** i)))
# USE STRIDED CONV INSTEAD OF POOLING
self.down_conv_list = nn.ModuleList()
for i in range(0, layers - 1):
self.down_conv_list.append(
nn.Conv2d(planes * 2 ** i, planes * 2 ** (i + 1), stride=2, kernel_size=kernel, padding=self.padding))
# CREATE MODULE FOR BOTTOM BLOCK
self.bottom = block(planes * (2 ** (layers - 1)), planes * (2 ** (layers - 1)))
# CREATE MODULE LIST FOR UP BRANCH
self.up_conv_list = nn.ModuleList()
self.up_dense_list = nn.ModuleList()
def forward(self, x):
raise NotImplementedError
class LadderBlockHW(LadderBlockBase):
"""
LadderBlockHW - LadderBlock for the Heavy-Weight ShelfNet Architecture
"""
def __init__(self, planes, layers, block=ShelfBlock, *args, **kwargs):
super().__init__(planes=planes, layers=layers, block=block, *args, **kwargs)
for i in range(0, layers - 1):
self.up_conv_list.append(nn.ConvTranspose2d(planes * 2 ** (layers - i - 1),
planes * 2 ** max(0, layers - i - 2),
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias=True))
self.up_dense_list.append(block(planes * 2 ** max(0, layers - i - 2), planes * 2 ** max(0, layers - i - 2)))
def forward(self, x):
out = self.inconv(x[-1])
down_out = []
# down branch
for i in range(0, self.layers - 1):
out = out + x[-i - 1]
out = self.down_module_list[i](out)
down_out.append(out)
out = self.down_conv_list[i](out)
out = F.relu(out)
# bottom branch
out = self.bottom(out)
bottom = out
# up branch
up_out = []
up_out.append(bottom)
for j in range(0, self.layers - 1):
out = self.up_conv_list[j](out) + down_out[self.layers - j - 2]
out = self.up_dense_list[j](out)
up_out.append(out)
return up_out
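# Illustrative sketch (not part of the original module): the ladder consumes the decoder's up-branch
# outputs (coarsest first, finest last), runs a strided-conv down branch with skip connections and a
# transposed-conv up branch. A hypothetical end-to-end shape check for one decoder/ladder pair:
def _example_decoder_ladder_hw(planes: int = 64, layers: int = 4):
    decoder = DecoderHW(planes=planes, layers=layers)
    ladder = LadderBlockHW(planes=planes, layers=layers)
    feats = [torch.randn(1, planes * 2 ** i, 64 // 2 ** i, 64 // 2 ** i) for i in range(layers)]
    ladder_out = ladder(decoder(feats))
    # ladder_out[-1] IS THE FINEST MAP: (1, planes, 64, 64)
    return [f.shape for f in ladder_out]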
class LadderBlockLW(LadderBlockBase):
"""
LadderBlockLW - LadderBlock for the Light-Weight ShelfNet Architecture
"""
def __init__(self, planes, layers, block=ShelfBlock, *args, **kwargs):
super().__init__(planes=planes, layers=layers, block=block, *args, **kwargs)
for i in range(0, layers - 1):
self.up_conv_list.append(
AttentionRefinementModule(planes * 2 ** (layers - 1 - i), planes * 2 ** max(0, layers - i - 2))
)
self.up_dense_list.append(
ConvBNReLU(in_chan=planes * 2 ** max(0, layers - i - 2), out_chan=planes * 2 ** max(0, layers - i - 2),
ks=3, stride=1))
def forward(self, x):
out = self.inconv(x[-1])
down_out = []
# DOWN BRANCH
for i in range(0, self.layers - 1):
out = out + x[-i - 1]
out = self.down_module_list[i](out)
down_out.append(out)
out = self.down_conv_list[i](out)
out = F.relu(out)
# BOTTOM BRANCH
out = self.bottom(out)
bottom = out
# UP BRANCH
up_out = []
up_out.append(bottom)
for j in range(0, self.layers - 1):
out = self.up_conv_list[j](out)
out = F.interpolate(out, (out.size(2) * 2, out.size(3) * 2), mode='nearest') + down_out[self.layers - j - 2]
out = self.up_dense_list[j](out)
up_out.append(out)
return up_out
class NetOutput(ShelfNetModuleBase):
def __init__(self, in_chan: int, mid_chan: int, classes_num: int):
super(NetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, classes_num, kernel_size=3, bias=False,
padding=1)
self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.conv_out(x)
return x
def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if ly.bias is not None:
nn.init.constant_(ly.bias, 0)
class ShelfNetBase(ShelfNetModuleBase):
"""
ShelfNetBase - ShelfNet Base Generic Architecture
"""
def __init__(self, backbone: ShelfResNetBackBone, planes: int, layers: int, classes_num: int = 21,
image_size: int = 512,
net_output_mid_channels_num: int = 64, arch_params: HpmStruct = None):
self.classes_num = arch_params.num_classes if (arch_params and hasattr(arch_params, 'num_classes')) else classes_num
self.image_size = arch_params.image_size if (arch_params and hasattr(arch_params, 'image_size')) else image_size
super().__init__()
self.net_output_mid_channels_num = net_output_mid_channels_num
self.backbone = backbone(self.classes_num)
self.layers = layers
self.planes = planes
# INITIALIZE WITH THE AUXILARY HEAD OUTPUTS ON -> TURN IT OFF TO RUN A FORWARD PASS WITHOUT THE AUXILARY HEADS
self.auxilary_head_outputs = True
# DECODER AND LADDER SHOULD BE IMPLEMENTED BY THE INHERITING CLASS
self.decoder = None
self.ladder = None
# BUILD THE CONV_OUT LIST BASED ON THE AMOUNT OF LAYERS IN THE SHELFNET
self.conv_out_list = torch.nn.ModuleList()
def forward(self, x):
raise NotImplementedError
def update_param_groups(self, param_groups: list, lr: float, epoch: int, iter: int, training_params: HpmStruct,
total_batch: int) \
-> list:
"""
update_param_groups - Updates the learning rate of each parameter group: the backbone group keeps lr,
while every other ShelfNet group is set to lr * 10
"""
# LEARNING RATE FOR THE BACKBONE IS lr
param_groups[0]['lr'] = lr
for i in range(1, len(param_groups)):
# LEARNING RATE FOR OTHER SHELFNET PARAMS IS lr * 10
param_groups[i]['lr'] = lr * 10
return param_groups
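# Illustrative sketch (not part of the original module): during training the first parameter group
# (the backbone) follows the scheduled lr while every other group is kept at 10x that value.
# `_example_update_param_groups` is a hypothetical helper added only for clarity.
def _example_update_param_groups(net: "ShelfNetBase", lr: float = 1e-3):
    # TWO GROUPS, AS BUILT BY initialize_param_groups: BACKBONE FIRST, EVERYTHING ELSE SECOND
    groups = [{"lr": lr}, {"lr": lr}]
    groups = net.update_param_groups(groups, lr=lr, epoch=0, iter=0, training_params=None, total_batch=0)
    # groups[0]["lr"] == lr (BACKBONE), groups[1]["lr"] == lr * 10 (THE REST)
    return groups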
class ShelfNetHW(ShelfNetBase):
"""
ShelfNetHW - Heavy-Weight Version of ShelfNet
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ladder = LadderBlockHW(planes=self.net_output_mid_channels_num, layers=self.layers)
self.decoder = DecoderHW(planes=self.net_output_mid_channels_num, layers=self.layers)
self.se_layer = nn.Linear(self.net_output_mid_channels_num * 2 ** 3, self.classes_num)
self.aux_head = FCNHead(1024, self.classes_num)
self.final = nn.Conv2d(self.net_output_mid_channels_num, self.classes_num, 1)
# THE MID CHANNELS NUMBER OF THE NET OUTPUT BLOCK
net_out_planes = self.planes
mid_channels_num = self.net_output_mid_channels_num
# INITIALIZE THE conv_out_list
for i in range(self.layers):
self.conv_out_list.append(
ConvBNReLU(in_chan=net_out_planes, out_chan=mid_channels_num, ks=1, padding=0))
mid_channels_num *= 2
net_out_planes *= 2
def forward(self, x):
image_size = x.size()[2:]
backbone_features_list = list(self.backbone(x))
conv_bn_relu_results_list = []
for feature, conv_bn_relu in zip(backbone_features_list, self.conv_out_list):
out = conv_bn_relu(feature)
conv_bn_relu_results_list.append(out)
decoder_out_list = self.decoder(conv_bn_relu_results_list)
ladder_out_list = self.ladder(decoder_out_list)
preds = [self.final(ladder_out_list[-1])]
# SE_LOSS ENCODING
enc = F.max_pool2d(ladder_out_list[0], kernel_size=ladder_out_list[0].size()[2:])
enc = torch.squeeze(enc, -1)
enc = torch.squeeze(enc, -1)
se = self.se_layer(enc)
preds.append(se)
# UP SAMPLING THE TOP LAYER FOR PREDICTION
preds[0] = F.interpolate(preds[0], image_size, mode='bilinear', align_corners=True)
# AUXILARY HEAD OUTPUT (ONLY RELEVANT FOR LOSS CALCULATION) - USE self.auxilary_head_outputs=FALSE FOR INFERENCE
if self.auxilary_head_outputs or self.training:
aux_out = self.aux_head(backbone_features_list[2])
aux_out = F.interpolate(aux_out, image_size, mode='bilinear', align_corners=True)
preds.append(aux_out)
return tuple(preds)
else:
return preds[0]
def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
"""
initialize_param_groups - Builds the optimizer parameter groups, giving the backbone its own group
and grouping the rest of the ShelfNet layers (ladder, decoder, SE layer, output convs, final conv
and auxiliary head) together with a 10x learning rate
:param lr: lr to set for the backbone
:param training_params:
:return: list of dictionaries with named params and optimizer attributes
"""
# OPTIMIZER PARAMETER GROUPS
params_list = []
# OPTIMIZE BACKBONE USING DIFFERENT LR
params_list.append({'named_params': self.backbone.named_parameters(), 'lr': lr})
# OPTIMIZE MAIN SHELFNET ARCHITECTURE LAYERS
params_list.append({'named_params': list(self.ladder.named_parameters()) + list(
self.decoder.named_parameters()) + list(self.se_layer.named_parameters()) + list(
self.conv_out_list.named_parameters()) + list(self.final.named_parameters()) + list(
self.aux_head.named_parameters()), 'lr': lr * 10})
return params_list
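# Illustrative sketch (not part of the original module): with auxiliary outputs on (or in training
# mode) ShelfNetHW returns a 3-tuple of (main segmentation logits at input resolution, SE-loss class
# logits, auxiliary FCN-head logits at input resolution); with the flag off in eval mode it returns
# only the main logits. A hypothetical shape check, assuming a Bottleneck backbone such as ShelfNet50:
def _example_shelfnet_hw_outputs(net: "ShelfNetHW", x: torch.Tensor):
    net.train()
    main, se, aux = net(x)
    # main: (N, classes_num, H, W), se: (N, classes_num), aux: (N, classes_num, H, W)
    net.eval()
    net.auxilary_head_outputs = False
    with torch.no_grad():
        main_only = net(x)
    return main.shape, se.shape, aux.shape, main_only.shape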
class ShelfNetLW(ShelfNetBase):
"""
ShelfNetLW - Light-Weight Implementation for ShelfNet
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.net_output_list = nn.ModuleList()
self.ladder = LadderBlockLW(planes=self.planes, layers=self.layers)
self.decoder = DecoderLW(planes=self.planes, layers=self.layers)
def forward(self, x):
H, W = x.size()[2:]
# THE SHELFNET LW ARCHITECTURE USES ONLY THE LAST 3 OF THE BACKBONE'S 4 OUTPUT FEATURE MAPS
backbone_features_tuple = self.backbone(x)[1:]
if isinstance(self, ShelfNet18_LW):
# FOR SHELFNET18 USE 1x1 CONVS AFTER THE BACKBONE'S FORWARD PASS TO MANIPULATE THE CHANNELS FOR THE DECODER
conv_bn_relu_results_list = []
for feature, conv_bn_relu in zip(backbone_features_tuple, self.conv_out_list):
out = conv_bn_relu(feature)
conv_bn_relu_results_list.append(out)
else:
# FOR SHELFNET34 THE CHANNELS ARE ALREADY ALIGNED
conv_bn_relu_results_list = list(backbone_features_tuple)
decoder_out_list = self.decoder(conv_bn_relu_results_list)
ladder_out_list = self.ladder(decoder_out_list)
# GET THE LAST ELEMENTS OF THE LADDER_BLOCK BASED ON THE AMOUNT OF SHELVES IN THE ARCHITECTURE AND REVERSE LIST
feat_cp_list = list(reversed(ladder_out_list[(-1 * self.layers):]))
feat_out = self.net_output_list[0](feat_cp_list[0])
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
if self.auxilary_head_outputs or self.training:
features_out_list = [feat_out]
for conv_output_layer, feat_cp in zip(self.net_output_list[1:], feat_cp_list[1:]):
feat_out_res = conv_output_layer(feat_cp)
feat_out_res = F.interpolate(feat_out_res, (H, W), mode='bilinear', align_corners=True)
features_out_list.append(feat_out_res)
return tuple(features_out_list)
else:
# THIS DOES NOT CALCULATE THE AUXILARY HEADS THAT ARE CRITICAL FOR THE LOSS (USED MAINLY FOR INFERENCE)
return feat_out
def initialize_param_groups(self, lr: float, training_params: HpmStruct) -> list:
"""
initialize_param_groups - Builds the optimizer parameter groups, with a 10x learning rate
for everything but the backbone
:param lr: lr to set for the backbone
:param training_params:
:return: list of dictionaries with named params and optimizer attributes
"""
# OPTIMIZER PARAMETER GROUPS
params_list = []
# OPTIMIZE BACKBONE USING DIFFERENT LR
params_list.append({'named_params': self.backbone.named_parameters(), 'lr': lr})
# OPTIMIZE MAIN SHELFNET ARCHITECTURE LAYERS
params_list.append({'named_params': list(self.ladder.named_parameters()) + list(
self.decoder.named_parameters()) + list(
self.conv_out_list.named_parameters()), 'lr': lr * 10})
return params_list
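# Illustrative sketch (not part of the original module): the light-weight net returns one prediction
# per shelf level (all upsampled to the input size) when auxiliary outputs are enabled, and only the
# finest prediction otherwise. Training mode assumes a batch size >= 2 so the BatchNorm layers inside
# the attention refinement modules (which see 1x1 pooled descriptors) can compute batch statistics.
def _example_shelfnet_lw_outputs(net: "ShelfNetLW", x: torch.Tensor):
    net.train()
    preds = net(x)  # TUPLE OF LENGTH net.layers, EACH (N, classes_num, H, W)
    net.eval()
    net.auxilary_head_outputs = False
    with torch.no_grad():
        main_only = net(x)  # SINGLE TENSOR (N, classes_num, H, W)
    return len(preds), main_only.shape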
class ShelfNet18_LW(ShelfNetLW):
def __init__(self, *args, **kwargs):
super().__init__(backbone=ShelfResNetBackBone18, planes=64, layers=3, *args, **kwargs)
# INITIALIZE THE net_output_list AND THE conv_out LIST
out_planes = self.planes
for i in range(self.layers):
# THE MID CHANNELS NUMBER OF THE NET OUTPUT BLOCK
mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
self.net_output_list.append(
NetOutput(out_planes, mid_channels_num, self.classes_num))
self.conv_out_list.append(
ConvBNReLU(out_planes * 2, out_planes, ks=1, stride=1, padding=0)
)
out_planes *= 2
class ShelfNet34_LW(ShelfNetLW):
def __init__(self, *args, **kwargs):
super().__init__(backbone=ShelfResNetBackBone34, planes=128, layers=3, *args, **kwargs)
# INITIALIZE THE net_output_list
net_out_planes = self.planes
for i in range(self.layers):
# IF IT'S THE FIRST LAYER THEN THE MID-CHANNELS NUM IS ACTUALLY self.planes
mid_channels_num = self.planes if i == 0 else self.net_output_mid_channels_num
self.net_output_list.append(
NetOutput(net_out_planes, mid_channels_num, self.classes_num))
net_out_planes *= 2
class ShelfNet503343(ShelfNetHW):
def __init__(self, *args, **kwargs):
super().__init__(backbone=ShelfResNetBackBone503343, planes=256, layers=4, *args, **kwargs)
class ShelfNet50(ShelfNetHW):
def __init__(self, *args, **kwargs):
super().__init__(backbone=ShelfResNetBackBone50, planes=256, layers=4, *args, **kwargs)
class ShelfNet101(ShelfNetHW):
def __init__(self, *args, **kwargs):
super().__init__(backbone=ShelfResNetBackBone101, planes=256, layers=4, *args, **kwargs)
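# Illustrative usage sketch (not part of the original module): constructing a light-weight ShelfNet
# from an HpmStruct and running an inference forward pass. The arch_params fields used here
# (num_classes, image_size) are the ones read by ShelfNetBase; everything else is assumed default.
def _example_build_shelfnet18_lw():
    arch_params = HpmStruct(num_classes=21, image_size=512)
    net = ShelfNet18_LW(arch_params=arch_params)
    net.eval()
    net.auxilary_head_outputs = False
    with torch.no_grad():
        out = net(torch.randn(1, 3, 512, 512))
    # out: (1, 21, 512, 512) SEGMENTATION LOGITS
    return out.shape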