---
title: Layers
keywords: fastai
sidebar: home_sidebar
summary: "Helper functions used to build PyTorch timeseries models."
description: "Helper functions used to build PyTorch timeseries models."
nb_path: "nbs/100_models.layers.ipynb"
---
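# Conv2dSame pads so that the output spatial dims match the input for stride=1 (any dilation)
# and are halved for stride=2, as the shape tests below show.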
bs = 2
c_in = 3
c_out = 5
h = 16
w = 20
t = torch.rand(bs, c_in, h, w)
test_eq(Conv2dSame(c_in, c_out, ks=3, stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))
test_eq(Conv2dSame(c_in, c_out, ks=(3, 1), stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))
test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(1, 1), dilation=(2, 2), bias=False)(t).shape, (bs, c_out, h, w))
test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(2, 2), dilation=(1, 1), bias=False)(t).shape, (bs, c_out, h//2, w//2))
test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(2, 2), dilation=(2, 2), bias=False)(t).shape, (bs, c_out, h//2, w//2))
test_eq(Conv2d(c_in, c_out, ks=3, padding='same', stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))
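# CausalConv1d should produce the same output length as a 'same'-padded Conv1d while only
# padding on the left, so no future time steps are used.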
bs = 2
c_in = 3
c_out = 5
seq_len = 512
t = torch.rand(bs, c_in, seq_len)
dilation = 1
test_eq(CausalConv1d(c_in, c_out, ks=3, dilation=dilation)(t).shape, Conv1d(c_in, c_out, ks=3, padding="same", dilation=dilation)(t).shape)
dilation = 2
test_eq(CausalConv1d(c_in, c_out, ks=3, dilation=dilation)(t).shape, Conv1d(c_in, c_out, ks=3, padding="same", dilation=dilation)(t).shape)
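# Conv1d accepts an int padding or one of 'valid' (no padding), 'same' or 'causal';
# kernel_size and ks are aliases and cannot be passed together.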
bs = 2
ni = 3
nf = 5
seq_len = 6
ks = 3
t = torch.rand(bs, ni, seq_len)
test_eq(Conv1d(ni, nf, ks, padding=0)(t).shape, (bs, nf, seq_len - (2 * (ks//2))))
test_eq(Conv1d(ni, nf, ks, padding='valid')(t).shape, (bs, nf, seq_len - (2 * (ks//2))))
test_eq(Conv1d(ni, nf, ks, padding='same')(t).shape, (bs, nf, seq_len))
test_eq(Conv1d(ni, nf, ks, padding='causal')(t).shape, (bs, nf, seq_len))
test_error('use kernel_size or ks but not both simultaneously', Conv1d, ni, nf, kernel_size=3, ks=3)
test_error('you need to pass a ks', Conv1d, ni, nf)
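# The cells below initialize the conv weights with init_linear and, for the 'valid' and 0 padding
# cases, also reparameterize them with weight_norm before displaying the module.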
conv = Conv1d(ni, nf, ks, padding='same')
init_linear(conv, None, init='auto', bias_std=.01)
conv
conv = Conv1d(ni, nf, ks, padding='causal')
init_linear(conv, None, init='auto', bias_std=.01)
conv
conv = Conv1d(ni, nf, ks, padding='valid')
init_linear(conv, None, init='auto', bias_std=.01)
weight_norm(conv)
conv
conv = Conv1d(ni, nf, ks, padding=0)
init_linear(conv, None, init='auto', bias_std=.01)
weight_norm(conv)
conv
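# SeparableConv1d (a depthwise conv followed by a pointwise 1x1 conv) keeps the sequence length unchanged here.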
bs = 64
c_in = 6
c_out = 5
seq_len = 512
t = torch.rand(bs, c_in, seq_len)
test_eq(SeparableConv1d(c_in, c_out, 3)(t).shape, (bs, c_out, seq_len))
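# AddCoords1d appends a normalized position channel (CoordConv-style), so the channel dim grows by 1
# while the overall mean/std stay close to 0/1.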
bs = 2
c_in = 3
c_out = 5
seq_len = 50
t = torch.rand(bs, c_in, seq_len)
t = (t - t.mean()) / t.std()
test_eq(AddCoords1d()(t).shape, (bs, c_in + 1, seq_len))
new_t = AddCoords1d()(t)
test_close(new_t.mean(),0, 1e-2)
test_close(new_t.std(), 1, 1e-2)
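# SEModule1d is a squeeze-and-excitation block: it rescales channels and leaves the input shape unchanged.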
t = torch.rand(8, 32, 12)
test_eq(SEModule1d(t.shape[1], 16, act=nn.ReLU, act_kwargs={})(t).shape, t.shape)
bs = 2
ni = 3
nf = 5
sl = 4
ks = 5
t = torch.rand(bs, ni, sl)
test_eq(ConvBlock(ni, nf, ks)(t).shape, (bs, nf, sl))
test_eq(ConvBlock(ni, nf, ks, padding='causal')(t).shape, (bs, nf, sl))
test_eq(ConvBlock(ni, nf, ks, coord=True)(t).shape, (bs, nf, sl))
test_eq(BN1d(ni)(t).shape, (bs, ni, sl))
test_eq(BN1d(ni).weight.data.mean().item(), 1.)
test_eq(BN1d(ni, zero_norm=True).weight.data.mean().item(), 0.)
test_eq(ConvBlock(ni, nf, ks, norm='batch', zero_norm=True)[1].weight.data.unique().item(), 0)
test_ne(ConvBlock(ni, nf, ks, norm='batch', zero_norm=False)[1].weight.data.unique().item(), 0)
test_eq(ConvBlock(ni, nf, ks, bias=False)[0].bias, None)
ConvBlock(ni, nf, ks, act=Swish, coord=True)
LinLnDrop(2, 3, p=.5)
bs = 2
nf = 5
sl = 4
t = torch.rand(bs, nf, sl)
test_eq(Permute(0,2,1)(t).shape, (bs, sl, nf))
test_eq(Max(1)(t).shape, (bs, sl))
test_eq(Transpose(1,2)(t).shape, (bs, sl, nf))
test_eq(Transpose(1,2, contiguous=True)(t).shape, (bs, sl, nf))
test_eq(View(-1, 2, 10)(t).shape, (bs, 1, 2, 10))
test_eq(Reshape(-1, 2, 10)(t).shape, (bs, 1, 2, 10))
Transpose(1,2), Permute(0,2,1), View(-1, 2, 10), Transpose(1,2, contiguous=True), Reshape(-1, 2, 10), Noop
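# DropPath (stochastic depth): with p=0 it is a no-op; with p>0 it zeroes whole samples and rescales
# the survivors, hence max >= 1 on an all-ones input.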
t = torch.ones(100,2,3)
test_eq(DropPath(0.)(t), t)
assert DropPath(0.5)(t).max() >= 1
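# Sharpen makes a probability distribution more peaked (e.g. by raising the probas to a power and
# renormalizing), so the largest per-sample probabilities increase.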
n_samples = 1000
n_classes = 3
t = (torch.rand(n_samples, n_classes) - .5) * 10
probas = F.softmax(t, -1)
sharpened_probas = Sharpen()(probas)
plt.plot(probas.flatten().sort().values, color='r')
plt.plot(sharpened_probas.flatten().sort().values, color='b')
plt.show()
test_gt(sharpened_probas[n_samples//2:].max(-1).values.sum().item(), probas[n_samples//2:].max(-1).values.sum().item())
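# get_calibrator returns identity-initialized calibration layers: 'temp'/'vector'/'matrix' pass the
# logits through unchanged at init, while the 'd*' (dirichlet) variants return log_softmax of the input.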
bs = 2
c_out = 3
t = torch.rand(bs, c_out)
for calibrator, cal_name in zip(['temp', 'vector', 'matrix'], ['Temp_Scale', 'Vector_Scale', 'Matrix_Scale']):
    cal = get_calibrator(calibrator, n_classes=c_out)
    # print(calibrator)
    # print(cal.weight, cal.bias, '\n')
    test_eq(cal(t), t)
    test_eq(cal.__class__.__name__, cal_name)
for calibrator, cal_name in zip(['dtemp', 'dvector', 'dmatrix'], ['Temp_Scale', 'Vector_Scale', 'Matrix_Scale']):
    cal = get_calibrator(calibrator, n_classes=c_out)
    # print(calibrator)
    # print(cal.weight, cal.bias, '\n')
    test_eq(cal(t), F.log_softmax(t, dim=1))
    test_eq(cal.__class__.__name__, cal_name)
bs = 2
c_out = 3
t = torch.rand(bs, c_out)
test_eq(Temp_Scale()(t).shape, t.shape)
test_eq(Vector_Scale(c_out)(t).shape, t.shape)
test_eq(Matrix_Scale(c_out)(t).shape, t.shape)
test_eq(Temp_Scale(dirichlet=True)(t).shape, t.shape)
test_eq(Vector_Scale(c_out, dirichlet=True)(t).shape, t.shape)
test_eq(Matrix_Scale(c_out, dirichlet=True)(t).shape, t.shape)
test_eq(Temp_Scale()(t), t)
test_eq(Vector_Scale(c_out)(t), t)
test_eq(Matrix_Scale(c_out)(t), t)
bs = 2
c_out = 5
t = torch.rand(bs, c_out)
test_eq(Vector_Scale(c_out)(t), t)
test_eq(Vector_Scale(c_out).weight.data, torch.ones(c_out))
test_eq(Vector_Scale(c_out).weight.requires_grad, True)
test_eq(type(Vector_Scale(c_out).weight), torch.nn.parameter.Parameter)
bs = 2
c_out = 3
weight = 2
bias = 1
t = torch.rand(bs, c_out)
test_eq(Matrix_Scale(c_out)(t).shape, t.shape)
test_eq(Matrix_Scale(c_out).weight.requires_grad, True)
test_eq(type(Matrix_Scale(c_out).weight), torch.nn.parameter.Parameter)
bs, n_classes = 16, 3
class_priors = torch.rand(n_classes)
logits = torch.randn(bs, n_classes) * 2
test_eq(LogitAdjLayer(class_priors)(logits), logits + class_priors)
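# MaxPPVPool1d concatenates max pooling with the proportion of positive values (ppv) per channel,
# doubling the channel dim.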
bs = 2
nf = 5
sl = 4
t = torch.rand(bs, nf, sl)
test_eq(MaxPPVPool1d()(t).shape, (bs, nf*2, 1))
test_eq(MaxPPVPool1d()(t).shape, AdaptiveConcatPool1d(1)(t).shape)
t = torch.randn(16, 64, 50)
head = gwa_pool_head(64, 5, 50)
test_eq(head(t).shape, (16, 5))
bs, c_in, seq_len = 16, 1, 50
c_out = 3
t = torch.rand(bs, c_in, seq_len)
test_eq(GAP1d()(t).shape, (bs, c_in))
test_eq(GACP1d()(t).shape, (bs, c_in*2))
bs, c_in, seq_len = 16, 4, 50
t = torch.rand(bs, c_in, seq_len)
test_eq(GAP1d()(t).shape, (bs, c_in))
test_eq(GACP1d()(t).shape, (bs, c_in*2))
test_eq(GAWP1d(c_in, seq_len, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=False)(t).shape, (bs, c_in))
test_eq(GAWP1d(c_in, seq_len, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=False)(t).shape, (bs, c_in))
test_eq(GAWP1d(c_in, seq_len, n_layers=1, ln=False, dropout=0.5, zero_init=False)(t).shape, (bs, c_in))
test_eq(GAWP1d(c_in, seq_len, n_layers=1, ln=False, dropout=0.5, zero_init=True)(t).shape, (bs, c_in))
test_eq(AttentionalPool1d(c_in, c_out)(t).shape, (bs, c_out, 1))
bs, c_in, seq_len = 16, 128, 50
c_out = 14
t = torch.rand(bs, c_in, seq_len)
attp = attentional_pool_head(c_in, c_out)
test_eq(attp(t).shape, (bs, c_out))
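# get_act_fn accepts an activation class, an instance, or a string name (plus kwargs) and returns the module.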
test_eq(get_act_fn(nn.ReLU).__repr__(), "ReLU()")
test_eq(get_act_fn(nn.ReLU()).__repr__(), "ReLU()")
test_eq(get_act_fn(nn.LeakyReLU, negative_slope=0.05).__repr__(), "LeakyReLU(negative_slope=0.05)")
test_eq(get_act_fn('reglu').__repr__(), "ReGLU()")
test_eq(get_act_fn('leakyrelu', negative_slope=0.05).__repr__(), "LeakyReLU(negative_slope=0.05)")
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_pool_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
test_eq(create_pool_head(nf, c_out, seq_len, concat_pool=True, fc_dropout=0.5)(t).shape, (bs, c_out))
create_pool_head(nf, c_out, seq_len, concat_pool=True, bn=True, fc_dropout=.5)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(max_pool_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_pool_plus_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
test_eq(create_pool_plus_head(nf, c_out, concat_pool=True, fc_dropout=0.5)(t).shape, (bs, c_out))
create_pool_plus_head(nf, c_out, seq_len, fc_dropout=0.5)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_conv_head(nf, c_out, seq_len)(t).shape, (bs, c_out))
test_eq(create_conv_head(nf, c_out, adaptive_size=50)(t).shape, (bs, c_out))
create_conv_head(nf, c_out, 50)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_mlp_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
t = torch.rand(bs, nf, seq_len)
create_mlp_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_fc_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
create_fc_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)
bs = 16
nf = 12
c_out = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
test_eq(create_rnn_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))
create_rnn_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)
bs = 16
nf = 12
ni = 2
seq_len = 20
t = torch.rand(bs, nf, seq_len)
head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=None, fc_dropout=0.)
test_eq(head(t).shape, (bs, ni, seq_len))
head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=(.3,.7), fc_dropout=0.)
test_ge(head(t).min(), .3)
test_le(head(t).max(), .7)
y_range = (tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.2000, 0.2000, 0.2000, 0.2000, 0.3000,
                   0.3000, 0.3000, 0.3000]),
           tensor([0.6000, 0.6000, 0.6000, 0.6000, 0.7000, 0.7000, 0.7000, 0.7000, 0.8000,
                   0.8000, 0.8000, 0.8000]))
test_ge(head(t).min(), .1)
test_le(head(t).max(), .9)
head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=y_range, fc_dropout=0.)
head
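# The *_nd_head heads map (bs, nf, seq_len) features to (bs, *d, c) outputs, squeezing the class dim
# when c == 1 (regression targets).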
bs = 16
nf = 32
c = 5
seq_len = 10
d = 2
targ = torch.randint(0, c, (bs,d))
t = torch.randn(bs, nf, seq_len)
head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=True, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, (bs, d, c))
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 5
seq_len = 10
d = [2, 8]
targ = torch.randint(0, c, [bs]+d)
t = torch.randn(bs, nf, seq_len)
head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=False, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, [bs]+d+[c])
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = 2
targ = torch.rand(bs, d)
t = torch.randn(bs, nf, seq_len)
head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=False, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, (bs, d))
loss = L1LossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = [2,3]
targ = torch.rand(bs, *d)
t = torch.randn(bs, nf, seq_len)
head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=False, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, [bs]+d)
loss = L1LossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 5
seq_len = 10
d = 2
targ = torch.randint(0, c, (bs,d))
t = torch.randn(bs, nf, seq_len)
head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, (bs, d, c))
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 5
seq_len = 10
d = [2, 8]
targ = torch.randint(0, c, [bs]+d)
t = torch.randn(bs, nf, seq_len)
head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, [bs]+d+[c])
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = 2
targ = torch.rand(bs, d)
t = torch.randn(bs, nf, seq_len)
head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, (bs, d))
loss = L1LossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = [2,3]
targ = torch.rand(bs, *d)
t = torch.randn(bs, nf, seq_len)
head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)
inp = head(t)
test_eq(inp.shape, [bs]+d)
loss = L1LossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 5
seq_len = 10
d = 10
targ = torch.randint(0, c, (bs,d))
t = torch.randn(bs, nf, seq_len)
head = conv_3d_head(nf, c, seq_len, d)
inp = head(t)
test_eq(inp.shape, (bs, d, c))
loss = CrossEntropyLossFlat()(inp, targ)
loss, head
bs = 16
nf = 32
c = 1
seq_len = 10
d = 10
targ = torch.rand(bs, d)
t = torch.randn(bs, nf, seq_len)
head = conv_3d_head(nf, c, seq_len, d)
inp = head(t)
test_eq(inp.shape, (bs, d))
loss = L1LossFlat()(inp, targ)
loss, head
bs, c_in, seq_len = 16, 128, 50
c_out = 14
t = torch.rand(bs, c_in, seq_len)
uph = universal_pool_head(c_in, c_out, seq_len)
test_eq(uph(t).shape, (bs, c_out))
uph = universal_pool_head(c_in, c_out, seq_len, 2)
test_eq(uph(t).shape, (bs, c_out))
bs, c_in, seq_len = 16, 128, 50
c_out = 14
d = 5
t = torch.rand(bs, c_in, seq_len)
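# Quick shape check over all exported heads (the nd heads and the 3d conv head take an extra output-dim argument).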
for head in heads:
    print(head.__name__)
    if head.__name__ == "create_conv_3d_head":
        h = head(c_in, c_out, seq_len, seq_len)
        test_eq(h(t).shape, (bs, seq_len, c_out))
    elif 'nd' in head.__name__:
        h = head(c_in, c_out, seq_len, d)
        test_eq(h(t).shape, (bs, d, c_out))
    else:
        h = head(c_in, c_out, seq_len)
        test_eq(h(t).shape, (bs, c_out))
bs = 2
ni = 32
sl = 4
t = torch.rand(bs, ni, sl)
test_eq(SqueezeExciteBlock(ni)(t).shape, (bs, ni, sl))
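# GaussianNoise adds random noise (so outputs differ from the input) while preserving the shape
# for 1d, 2d and 3d tensors.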
t = torch.ones(2,3,4)
test_ne(GaussianNoise()(t), t)
test_eq(GaussianNoise()(t).shape, t.shape)
t = torch.ones(2,3)
test_ne(GaussianNoise()(t), t)
test_eq(GaussianNoise()(t).shape, t.shape)
t = torch.ones(2)
test_ne(GaussianNoise()(t), t)
test_eq(GaussianNoise()(t).shape, t.shape)
ttest_tensor(a, b)
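# PositionwiseFeedForward (the transformer-style MLP block) preserves the input shape; activations
# such as 'reglu' and 'smelu' are supported.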
t = torch.randn(2,3,10)
m = PositionwiseFeedForward(10, dropout=0., act='reglu', mlp_ratio=1)
test_eq(m(t).shape, t.shape)
m = PositionwiseFeedForward(10, dropout=0., act='smelu', mlp_ratio=1)
test_eq(m(t).shape, t.shape)
B = 16
C = 10
M = 1500 # seq_len
n_heads = 1
D = 128 # model dimension
N = 512 # max_seq_len - latent's index dimension
d_k = D // n_heads
xb = torch.randn(B, C, M)
xb = (xb - xb.mean()) / xb.std()
# Attention
# input (Q)
lin = nn.Linear(M, N, bias=False)
Q = lin(xb).transpose(1,2)
test_eq(Q.shape, (B, N, C))
# q
to_q = nn.Linear(C, D, bias=False)
q = to_q(Q)
q = nn.LayerNorm(D)(q)
# k, v
context = xb.transpose(1,2)
to_kv = nn.Linear(C, D * 2, bias=False)
k, v = to_kv(context).chunk(2, dim = -1)
k = k.transpose(-1, -2)
k = nn.LayerNorm(M)(k)
v = nn.LayerNorm(D)(v)
test_eq(q.shape, (B, N, D))
test_eq(k.shape, (B, D, M))
test_eq(v.shape, (B, M, D))
output, attn, scores = ScaledDotProductAttention(D, n_heads, res_attention=True)(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))
test_eq(output.shape, (B, 1, N, D))
test_eq(attn.shape, (B, 1, N, M))
test_eq(scores.shape, (B, 1, N, M))
scores.mean(), scores.std()
q = torch.rand([16, 3, 50, 8])
k = torch.rand([16, 3, 50, 8]).transpose(-1, -2)
v = torch.rand([16, 3, 50, 6])
attn_mask = torch.triu(torch.ones(50, 50)) # shape: q_len x q_len
key_padding_mask = torch.zeros(16, 50)
key_padding_mask[[1, 3, 6, 15], -10:] = 1
key_padding_mask = key_padding_mask.bool()
print('attn_mask', attn_mask.shape, 'key_padding_mask', key_padding_mask.shape)
output, attn = ScaledDotProductAttention(24, 3, attn_dropout=.1)(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
output.shape, attn.shape
t = torch.rand(16, 50, 128)
output, attn = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)(t, t, t, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
output.shape, attn.shape
Test multi-head attention with locality self-attention (lsa)
t = torch.rand(16, 50, 128)
attn_mask = torch.eye(50).reshape(1, 1, 50, 50).bool()
output, attn = MultiheadAttention(d_model=128, n_heads=8, lsa=True)(t, t, t, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
output.shape, attn.shape
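# A float attn_mask is added to the attention scores (-inf fully masks a position), whereas a bool
# attn_mask marks positions to ignore; both should give nan-free outputs and gradients.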
t = torch.rand(16, 50, 128)
att_mask = (torch.rand((50, 50)) > .85).float()
att_mask[att_mask == 1] = -np.inf
mha = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)
output, attn = mha(t, t, t, attn_mask=att_mask)
test_eq(torch.isnan(output).sum().item(), 0)
test_eq(torch.isnan(attn).sum().item(), 0)
loss = output[:2, :].sum()
test_eq(torch.isnan(loss).sum().item(), 0)
loss.backward()
for n, p in mha.named_parameters():
    if p.grad is not None:
        test_eq(torch.isnan(p.grad).sum().item(), 0)
t = torch.rand(16, 50, 128)
attn_mask = (torch.rand((50, 50)) > .85)
# True values will be masked
mha = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)
output, attn = mha(t, t, t, attn_mask=attn_mask)
test_eq(torch.isnan(output).sum().item(), 0)
test_eq(torch.isnan(attn).sum().item(), 0)
loss = output[:2, :].sum()
test_eq(torch.isnan(loss).sum().item(), 0)
loss.backward()
for n, p in mha.named_parameters():
    if p.grad is not None:
        test_eq(torch.isnan(p.grad).sum().item(), 0)
t = torch.rand(16, 6, 37)
test_eq(MultiConv1d(6, None, kss=[1,3,5], keep_original=True)(t).shape, [16, 24, 37])
test_eq(MultiConv1d(6, 36, kss=[1,3,5], keep_original=False)(t).shape, [16, 36, 37])
test_eq(MultiConv1d(6, None, kss=[1,3,5], keep_original=True, dim=-1)(t).shape, [16, 6, 37*4])
test_eq(MultiConv1d(6, 60, kss=[1,3,5], keep_original=True)(t).shape, [16, 60, 37])
test_eq(MultiConv1d(6, 60, kss=[1,3,5], separable=True)(t).shape, [16, 60, 37])
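# LSTMOutput returns only the output tensor (the first element) of an RNN/LSTM output tuple.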
t = ([1], [2], [3])
test_eq(LSTMOutput()(t), [1])
test_eq(emb_sz_rule(7), 5)
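# MultiEmbedding embeds the categorical channels listed in cat_pos (embedding sizes from emb_sz_rule)
# and concatenates them with the remaining continuous channels along dim 1.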
a = alphabet[np.random.randint(0,3,40)]
b = ALPHABET[np.random.randint(6,10,40)]
c = np.random.rand(40).reshape(4,1,10)
map_a = {k:v for v,k in enumerate(np.unique(a))}
map_b = {k:v for v,k in enumerate(np.unique(b))}
n_embeds = [len(m.keys()) for m in [map_a, map_b]]
szs = [emb_sz_rule(n) for n in n_embeds]
a = np.asarray(a.map(map_a)).reshape(4,1,10)
b = np.asarray(b.map(map_b)).reshape(4,1,10)
inp = torch.from_numpy(np.concatenate((c,a,b), 1)).float()
memb = MultiEmbedding(3, n_embeds, cat_pos=[1,2])
# registered buffers are part of the state_dict() but not module.parameters()
assert all([(k in memb.state_dict().keys()) for k in ['cat_pos', 'cont_pos']])
embeddings = memb(inp)
print(n_embeds, szs, inp.shape, embeddings.shape)
test_eq(embeddings.shape, (inp.shape[0],sum(szs)+1,inp.shape[-1]))
me = MultiEmbedding(3, 4, cat_pos=2)
test_eq(me.cat_embed[0].weight.shape, (4,3))
test_eq(me.cat_pos.cpu().item(), 2)