# coding=utf-8
__author__ = "Dimitrios Karkalousos"
import torch
import torch.nn as nn


class ConvRNNStack(nn.Module):
    """A stack of convolutional layers followed by a recurrent (RNN) layer."""

    def __init__(self, convs, rnn):
        """
        Parameters
        ----------
        convs: module applying the convolutional layers, e.g. an nn.Sequential of ConvNonlinear layers.
        rnn: the recurrent layer, called as rnn(x, hidden).
        """
        super().__init__()
        self.convs = convs
        self.rnn = rnn

    def forward(self, x, hidden):
        """
        Parameters
        ----------
        x: [batch_size, seq_len, input_size]
        hidden: [num_layers * num_directions, batch_size, hidden_size]

        Returns
        -------
        output: [batch_size, seq_len, hidden_size]
        """
        return self.rnn(self.convs(x), hidden)
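
# A minimal composition sketch (illustrative, not part of the original module): any
# convolutional stack and any recurrent module with the call signature
# rnn(x, hidden) -> output fit here, e.g. ConvNonlinear layers (defined below)
# feeding a hypothetical convolutional GRU cell:
#
#     convs = nn.Sequential(
#         ConvNonlinear(input_size=2, features=64, conv_dim=2, kernel_size=5, dilation=1, bias=True),
#         ConvNonlinear(input_size=64, features=64, conv_dim=2, kernel_size=3, dilation=1, bias=True),
#     )
#     stack = ConvRNNStack(convs, conv_gru_cell)  # conv_gru_cell: hypothetical recurrent cell
#     output = stack(x, hidden)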


class ConvNonlinear(nn.Module):
    """A convolutional layer with nonlinearity."""

    def __init__(self, input_size, features, conv_dim, kernel_size, dilation, bias, nonlinear="relu"):
        """
        Initializes the convolutional layer.

        Parameters
        ----------
        input_size: number of input channels.
        features: number of output channels.
        conv_dim: number of dimensions of the convolutional layer.
        kernel_size: size of the convolutional kernel.
        dilation: dilation of the convolutional kernel.
        bias: whether to use bias.
        nonlinear: nonlinearity of the convolutional layer, "relu" or None.
        """
        super().__init__()
        self.input_size = input_size
        self.features = features
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.bias = bias
        self.conv_dim = conv_dim
        self.conv_class = self.determine_conv_class(conv_dim)
        if nonlinear is not None and nonlinear.upper() == "RELU":
            self.nonlinear = nn.ReLU()
        elif nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            raise ValueError(f"Unknown nonlinearity: {nonlinear}. Supported: 'relu' or None.")
        # "Same" padding: replication-pad by dilation * (kernel_size - 1) // 2 on each
        # side so that the spatial size is preserved for odd kernel sizes.
        pad_size = dilation * (kernel_size - 1) // 2
        self.padding = [
            nn.ReplicationPad1d,
            nn.ReplicationPad2d,
            nn.ReplicationPad3d,
        ][conv_dim - 1](pad_size)
        self.conv_layer = self.conv_class(
            in_channels=input_size,
            out_channels=features,
            kernel_size=kernel_size,
            padding=0,  # padding is applied explicitly via self.padding in forward
            dilation=dilation,
            bias=bias,
        )
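        # Worked size check (illustrative): with kernel_size=3 and dilation=2,
        # pad_size = 2 * (3 - 1) // 2 = 2 per side, and the convolution consumes
        # dilation * (kernel_size - 1) = 4 samples, so L_out = L_in + 2*2 - 4 = L_in.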
        self.reset_parameters()

    def reset_parameters(self):
        """Resets the parameters of the convolutional layer."""
        nn.init.kaiming_normal_(self.conv_layer.weight, nonlinearity="relu")
        if self.conv_layer.bias is not None:
            nn.init.zeros_(self.conv_layer.bias)

    @staticmethod
    def determine_conv_class(n_dim):
        """Determines the convolutional layer class for 1, 2, or 3 dimensions."""
        if n_dim == 1:
            return nn.Conv1d
        if n_dim == 2:
            return nn.Conv2d
        if n_dim == 3:
            return nn.Conv3d
        raise ValueError(f"Convolution of {n_dim} dims is not implemented.")

    def forward(self, _input):
        """Forward pass: replication-pad, convolve, then apply the nonlinearity."""
        return self.nonlinear(self.conv_layer(self.padding(_input)))
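

# A minimal smoke test (a sketch, not part of the original module; shapes are
# assumptions for illustration). With kernel_size=3 and dilation=1 the replication
# padding preserves the 2D spatial size, so only the channel dimension changes.
if __name__ == "__main__":
    conv = ConvNonlinear(input_size=2, features=8, conv_dim=2, kernel_size=3, dilation=1, bias=True)
    x = torch.randn(4, 2, 32, 32)  # [batch, channels, height, width]
    y = conv(x)
    print(y.shape)  # expected: torch.Size([4, 8, 32, 32])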