# coding=utf-8
__author__ = "Dimitrios Karkalousos"
# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/recurrent/recurrent.py
# Copyright (c) DIRECT Contributors
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv2dGRU(nn.Module):
"""2D Convolutional GRU Network."""
def __init__(
self,
in_channels: int,
hidden_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 2,
        gru_kernel_size: int = 1,
orthogonal_initialization: bool = True,
instance_norm: bool = False,
dense_connect: int = 0,
replication_padding: bool = True,
):
"""
Inits Conv2dGRU.
Parameters
----------
in_channels: Number of input channels.
int
hidden_channels: Number of hidden channels.
int
out_channels: Number of output channels. If None, same as in_channels.
int (optional), Default: None.
num_layers: Number of layers.
int, Default: 2.
gru_kernel_size: Size of the GRU kernel.
int, Default: 1.
orthogonal_initialization: Orthogonal initialization is used if set to True.
bool, Default: True.
instance_norm: Instance norm is used if set to True.
bool, Default: False.
dense_connect: Number of dense connections.
replication_padding: If set to true replication padding is applied.
"""
super().__init__()
if out_channels is None:
out_channels = in_channels
self.num_layers = num_layers
self.hidden_channels = hidden_channels
self.dense_connect = dense_connect
self.reset_gates = nn.ModuleList([])
self.update_gates = nn.ModuleList([])
self.out_gates = nn.ModuleList([])
self.conv_blocks = nn.ModuleList([])
# Create convolutional blocks
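        # The first block takes the cell input; intermediate blocks take the previous
        # activation plus up to `dense_connect` skip connections; the final (extra)
        # block maps the hidden channels to the output channels.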
for idx in range(num_layers + 1):
in_ch = in_channels if idx == 0 else (1 + min(idx, dense_connect)) * hidden_channels
out_ch = hidden_channels if idx < num_layers else out_channels
            # Layer 0 uses a 5x5 kernel and layer 1 a 3x3 kernel with dilation 2; both
            # need effective padding 2 to preserve spatial size, later layers need 1.
            padding = 0 if replication_padding else (2 if idx < 2 else 1)
            block = []
            if replication_padding:
                block.append(nn.ReplicationPad2d(2 if idx < 2 else 1))
block.append(
nn.Conv2d(
in_channels=in_ch,
out_channels=out_ch,
kernel_size=5 if idx == 0 else 3,
dilation=(2 if idx == 1 else 1),
padding=padding,
)
)
self.conv_blocks.append(nn.Sequential(*block))
# Create GRU blocks
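        # Each layer gets its own reset, update, and output gate, each an (optionally
        # instance-normalized) convolution over the concatenated [input, hidden] channels.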
for _ in range(num_layers):
for gru_part in [self.reset_gates, self.update_gates, self.out_gates]:
block = []
if instance_norm:
block.append(nn.InstanceNorm2d(2 * hidden_channels))
block.append(
nn.Conv2d(
in_channels=2 * hidden_channels,
out_channels=hidden_channels,
kernel_size=gru_kernel_size,
padding=gru_kernel_size // 2,
)
)
gru_part.append(nn.Sequential(*block))
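        # Orthogonal weights help keep the recurrence well-conditioned; the reset-gate
        # bias of -1.0 starts that gate mostly closed, so the candidate activation
        # initially depends mostly on the current input.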
if orthogonal_initialization:
for reset_gate, update_gate, out_gate in zip(self.reset_gates, self.update_gates, self.out_gates):
nn.init.orthogonal_(reset_gate[-1].weight)
nn.init.orthogonal_(update_gate[-1].weight)
nn.init.orthogonal_(out_gate[-1].weight)
nn.init.constant_(reset_gate[-1].bias, -1.0)
nn.init.constant_(update_gate[-1].bias, 0.0)
nn.init.constant_(out_gate[-1].bias, 0.0)
    def forward(
self,
cell_input: torch.Tensor,
        previous_state: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes Conv2dGRU forward pass given tensors `cell_input` and `previous_state`.
Parameters
----------
cell_input: Reconstruction input
previous_state: Tensor of previous states.
Returns
-------
Output and new states.
"""
new_states: List[torch.Tensor] = []
conv_skip: List[torch.Tensor] = []
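        # Lazily initialize the hidden state with zeros on the first call.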
if previous_state is None:
batch_size, spatial_size = cell_input.size(0), (cell_input.size(2), cell_input.size(3))
state_size = [batch_size, self.hidden_channels] + list(spatial_size) + [self.num_layers]
            previous_state = torch.zeros(*state_size, dtype=cell_input.dtype, device=cell_input.device)
for idx in range(self.num_layers):
if len(conv_skip) > 0:
cell_input = F.relu(
self.conv_blocks[idx](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1)),
inplace=True,
)
else:
cell_input = F.relu(self.conv_blocks[idx](cell_input), inplace=True)
if self.dense_connect > 0:
conv_skip.append(cell_input)
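            # Standard convolutional GRU update for layer `idx`:
            #   update z = sigmoid(W_z * [x, h]), reset r = sigmoid(W_r * [x, h]),
            #   candidate h~ = tanh(W_o * [x, r * h]), new state h' = (1 - z) * h + z * h~.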
stacked_inputs = torch.cat([cell_input, previous_state[:, :, :, :, idx]], dim=1)
update = torch.sigmoid(self.update_gates[idx](stacked_inputs))
reset = torch.sigmoid(self.reset_gates[idx](stacked_inputs))
delta = torch.tanh(
self.out_gates[idx](torch.cat([cell_input, previous_state[:, :, :, :, idx] * reset], dim=1))
)
cell_input = previous_state[:, :, :, :, idx] * (1 - update) + delta * update
new_states.append(cell_input)
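            # Not inplace: this tensor was just stored in `new_states` above.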
cell_input = F.relu(cell_input, inplace=False)
if len(conv_skip) > 0:
out = self.conv_blocks[self.num_layers](torch.cat([*conv_skip[-self.dense_connect :], cell_input], dim=1))
else:
out = self.conv_blocks[self.num_layers](cell_input)
return out, torch.stack(new_states, dim=-1)
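

# Minimal usage sketch (not part of the original module; shapes are illustrative).
# It unrolls the cell for a few iterations, feeding the stacked hidden states back
# in, as a cascaded reconstruction network typically would.
if __name__ == "__main__":
    model = Conv2dGRU(in_channels=2, hidden_channels=64, num_layers=4)
    x = torch.randn(1, 2, 32, 32)  # e.g. one 2-channel (real/imaginary) image
    hidden = None
    for _ in range(3):
        out, hidden = model(x, hidden)
    print(out.shape)     # torch.Size([1, 2, 32, 32])
    print(hidden.shape)  # torch.Size([1, 64, 32, 32, 4])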