Source code for distil.utils.models.dla_simple

'''Simplified version of DLA in PyTorch.

Note that this implementation is not identical to the original paper version,
but it seems to work fine.

See dla.py for the original paper version.

Reference:
    Deep Layer Aggregation. https://arxiv.org/abs/1707.06484
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
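
# Illustrative sketch, not part of the original module (the helper name is
# ours): when the stride or channel count changes, the 1x1-conv shortcut
# reshapes the input so the residual sum is well-defined.
def _check_basic_block():
    block = BasicBlock(in_planes=16, planes=32, stride=2)
    y = block(torch.randn(1, 16, 32, 32))
    assert y.shape == (1, 32, 16, 16)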
class Root(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=1, padding=(kernel_size - 1) // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, xs):
        x = torch.cat(xs, 1)
        out = F.relu(self.bn(self.conv(x)))
        return out
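
# Illustrative sketch, not part of the original module (the helper name is
# ours): Root aggregates a list of feature maps by concatenating along the
# channel dimension and projecting back with a 1x1 convolution.
def _check_root():
    root = Root(in_channels=128, out_channels=64)
    xs = [torch.randn(1, 64, 8, 8), torch.randn(1, 64, 8, 8)]
    assert root(xs).shape == (1, 64, 8, 8)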
class Tree(nn.Module):
    def __init__(self, block, in_channels, out_channels, level=1, stride=1):
        super(Tree, self).__init__()
        # The root fuses the two subtree outputs, hence 2*out_channels in.
        self.root = Root(2*out_channels, out_channels)
        if level == 1:
            self.left_tree = block(in_planes=in_channels, planes=out_channels,
                                   stride=stride) if False else block(
                in_channels, out_channels, stride=stride)
            self.right_tree = block(out_channels, out_channels, stride=1)
        else:
            # Recurse: only the left subtree carries the stride.
            self.left_tree = Tree(block, in_channels, out_channels,
                                  level=level-1, stride=stride)
            self.right_tree = Tree(block, out_channels, out_channels,
                                   level=level-1, stride=1)

    def forward(self, x):
        out1 = self.left_tree(x)
        out2 = self.right_tree(out1)
        out = self.root([out1, out2])
        return out
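
# Illustrative sketch, not part of the original module (the helper name is
# ours): a level=2 Tree nests level=1 Trees; the stride is applied exactly
# once, in the leftmost branch, so the spatial resolution halves only once.
def _check_tree():
    tree = Tree(BasicBlock, in_channels=64, out_channels=128, level=2, stride=2)
    y = tree(torch.randn(1, 64, 16, 16))
    assert y.shape == (1, 128, 8, 8)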
class SimpleDLA(nn.Module):
    def __init__(self, num_classes=10, block=BasicBlock):
        super(SimpleDLA, self).__init__()
        self.embDim = 512
        self.base = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )

        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True)
        )

        self.layer3 = Tree(block, 32, 64, level=1, stride=1)
        self.layer4 = Tree(block, 64, 128, level=2, stride=2)
        self.layer5 = Tree(block, 128, 256, level=2, stride=2)
        self.layer6 = Tree(block, 256, 512, level=1, stride=2)
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x, last=False):
        out = self.base(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = F.avg_pool2d(out, 4)
        e = out.view(out.size(0), -1)
        out = self.linear(e)
        if last:
            # Also return the 512-d penultimate embedding.
            return out, e
        else:
            return out

    def get_embedding_dim(self):
        return self.embDim
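
# Illustrative sketch, not part of the original module (the helper name is
# ours): with last=True, forward returns both the logits and the 512-d
# penultimate embedding, whose size matches get_embedding_dim().
def _check_simple_dla():
    net = SimpleDLA(num_classes=10)
    logits, emb = net(torch.randn(2, 3, 32, 32), last=True)
    assert logits.shape == (2, 10)
    assert emb.shape == (2, net.get_embedding_dim())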
def test():
    net = SimpleDLA()
    print(net)
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y.size())


if __name__ == '__main__':
    test()