# Source code for distil.utils.models.simpleNN_net

import torch.nn as nn
import torch.nn.functional as F


class TwoLayerNet(nn.Module):
    """Fully connected network with one ReLU hidden layer.

    Layout: ``Linear(input_dim, hidden_units) -> ReLU ->
    Linear(hidden_units, num_classes)``. The ReLU-activated hidden layer
    serves as the embedding returned by ``forward(x, last=True)``.

    Args:
        input_dim: Number of input features (size of the last dim of ``x``).
        num_classes: Number of output scores produced per sample.
        hidden_units: Width of the hidden layer; also the embedding size.
    """

    def __init__(self, input_dim, num_classes, hidden_units):
        super().__init__()
        # Keep these attribute names and this order: they determine the
        # state_dict keys and the order of parameters().
        self.linear1 = nn.Linear(input_dim, hidden_units)
        self.linear2 = nn.Linear(hidden_units, num_classes)
        self.embd_dim = hidden_units

    def forward(self, x, last=False):
        """Run the network on ``x``.

        Args:
            x: Input tensor whose last dimension has size ``input_dim``.
            last: When True, also return the hidden-layer embedding.

        Returns:
            The raw (no softmax applied) output scores, or the tuple
            ``(scores, embedding)`` when ``last`` is True.
        """
        hidden = F.relu(self.linear1(x))
        out = self.linear2(hidden)
        if not last:
            return out
        return out, hidden

    def get_embedding_dim(self):
        """Return the size of the embedding produced by ``forward(..., last=True)``."""
        return self.embd_dim
class ThreeLayerNet(nn.Module):
    """Fully connected network with two ReLU hidden layers.

    Layout: ``Linear(input_dim, h1) -> ReLU -> Linear(h1, h2) -> ReLU ->
    Linear(h2, num_classes)``. The output of the second hidden layer
    (after ReLU) serves as the embedding returned by
    ``forward(x, last=True)``.

    Args:
        input_dim: Number of input features (size of the last dim of ``x``).
        num_classes: Number of output scores produced per sample.
        h1: Width of the first hidden layer.
        h2: Width of the second hidden layer; also the embedding size.
    """

    def __init__(self, input_dim, num_classes, h1, h2):
        super().__init__()
        # Keep these attribute names and this order: they determine the
        # state_dict keys and the order of parameters().
        self.linear1 = nn.Linear(input_dim, h1)
        self.linear2 = nn.Linear(h1, h2)
        self.linear3 = nn.Linear(h2, num_classes)
        self.embd_dim = h2

    def forward(self, x, last=False):
        """Run the network on ``x``.

        Args:
            x: Input tensor whose last dimension has size ``input_dim``.
            last: When True, also return the second hidden layer's output.

        Returns:
            The raw (no softmax applied) output scores, or the tuple
            ``(scores, embedding)`` when ``last`` is True.
        """
        hidden = x
        # Both hidden layers are Linear followed by ReLU.
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        out = self.linear3(hidden)
        if not last:
            return out
        return out, hidden

    def get_embedding_dim(self):
        """Return the size of the embedding produced by ``forward(..., last=True)``."""
        return self.embd_dim