Source code for distil.utils.models.mnist_net

import torch
import torch.nn as nn
import torch.nn.functional as F


class MnistNet(nn.Module):
    """Small convolutional classifier with two conv and two linear layers.

    ``fc1`` takes 9216 input features (64 channels * 12 * 12), which is what
    the conv/pool stack below produces for single-channel 28x28 inputs.
    The final layer emits 10 scores per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout: applied to the 4-D conv feature map.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: applied to the 2-D output of fc1.
        # Dropout2d is only defined for 3-D/4-D inputs; feeding it a 2-D
        # tensor is deprecated in PyTorch and slated to become an error.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x, last=False):
        """Run the network on a batch.

        Args:
            x: Input batch; must flatten to 9216 features after the
                conv/pool stack (e.g. shape ``(N, 1, 28, 28)``).
            last: When true, also return the activations fed into ``fc2``.

        Returns:
            The ``fc2`` output, or a ``(output, embedding)`` tuple when
            ``last`` is true. The embedding is taken after ``dropout2``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, collapse everything else for fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        output = self.fc2(x)
        if last:
            return output, x
        return output

    def get_embedding_dim(self):
        """Return the width of the embedding returned by ``forward(last=True)``."""
        # Read from the layer so it cannot drift from fc1's definition.
        return self.fc1.out_features