Source code for mridc.collections.reconstruction.models.conv.conv2d

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/conv/conv.py
# Copyright (c) DIRECT Contributors

import torch.nn as nn


class Conv2d(nn.Module):
    """
    Implementation of a simple cascade of 2D convolutions.

    Every convolution uses a 3x3 kernel with a padding of 1. If batchnorm is set to True, a batch normalization
    layer is applied after each convolution. The activation is applied after every convolution except the last one.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, n_convs=3, activation=None, batchnorm=False):
        """
        Inits Conv2d.

        Parameters
        ----------
        in_channels: Number of input channels.
            int
        out_channels: Number of output channels.
            int
        hidden_channels: Number of hidden channels.
            int
        n_convs: Number of convolutional layers.
            int
        activation: Activation function. If None (default), a new ``nn.PReLU()`` is created for this instance.
            The same activation module object is reused after every non-final convolution, so any learnable
            parameters it holds are shared between those layers.
            torch.nn.Module or None
        batchnorm: If True a batch normalization layer is applied after every convolution.
            bool
        """
        super().__init__()

        # A module used as a default argument is created once at import time and would be shared (including its
        # learnable parameters) by every Conv2d instance. Create the default per instance instead.
        if activation is None:
            activation = nn.PReLU()

        layers = []
        last_idx = n_convs - 1
        for idx in range(n_convs):
            is_last = idx == last_idx
            conv_in = in_channels if idx == 0 else hidden_channels
            conv_out = out_channels if is_last else hidden_channels

            layers.append(nn.Conv2d(conv_in, conv_out, kernel_size=3, padding=1))
            if batchnorm:
                layers.append(nn.BatchNorm2d(conv_out, eps=1e-4))
            # No activation after the final convolution.
            if not is_last:
                layers.append(activation)

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """
        Performs the forward pass of Conv2d.

        Parameters
        ----------
        x: Input tensor. A 5-dimensional input has dimension 1 squeezed (it is only removed when its size is 1).
            If the last dimension then has size 2, it is moved to position 1 with ``permute(0, 3, 1, 2)``.
            NOTE(review): the trailing size-2 axis is presumably the real/imaginary pair of complex data, which
            would make in_channels equal to 2 in that case -- confirm against callers.

        Returns
        -------
        Convoluted output.
        """
        if x.dim() == 5:
            x = x.squeeze(1)
        if x.shape[-1] == 2:
            # Channels-last to channels-first, as expected by nn.Conv2d.
            x = x.permute(0, 3, 1, 2)
        return self.conv(x)