# Source code for mridc.collections.reconstruction.models.multidomain.multidomain

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/multidomainnet/multidomain.py
# Copyright (c) DIRECT Contributors

import torch
import torch.nn as nn
import torch.nn.functional as F

from mridc.collections.common.parts.fft import fft2c, ifft2c
from mridc.collections.common.parts.utils import complex_conj, complex_mul


class MultiDomainConv2d(nn.Module):
    """
    Multi-domain convolution layer.

    Half of the output channels come from a convolution applied directly to the input image. The other half come
    from a convolution applied to the 2D FFT of the input, whose result is mapped back with the inverse FFT. The
    two halves are concatenated along the channel dimension (image branch first).
    """

    def __init__(
        self,
        fft_type,
        in_channels,
        out_channels,
        **kwargs,
    ):
        """
        Parameters
        ----------
        fft_type: FFT type, forwarded to ``fft2c`` / ``ifft2c``.
        in_channels: Number of input channels. The FFT consumes them in consecutive pairs, so this presumably has to
            be even (assumption based on ``torch.split(..., 2, -1)`` — confirm against ``fft2c``).
        out_channels: Total number of output channels. Each branch produces ``out_channels // 2`` channels, so an
            odd value yields one channel fewer than requested.
        **kwargs: Forwarded unchanged to both ``nn.Conv2d`` constructors (e.g. ``kernel_size``, ``padding``).
        """
        super().__init__()
        branch_channels = out_channels // 2
        self.image_conv = nn.Conv2d(in_channels=in_channels, out_channels=branch_channels, **kwargs)
        self.kspace_conv = nn.Conv2d(in_channels=in_channels, out_channels=branch_channels, **kwargs)
        self.fft_type = fft_type
        self._channels_dim = 1
        # Spatial axes of the channels-last view produced by permute(0, 2, 3, 1).
        self._spatial_dims = [1, 2]

    def _image_to_kspace(self, image):
        """FFT every consecutive channel pair of a channels-first tensor and return it channels-first again."""
        channels_last = image.permute(0, 2, 3, 1).contiguous()
        transformed = [
            fft2c(pair, fft_type=self.fft_type, fft_dim=self._spatial_dims)
            for pair in torch.split(channels_last, 2, -1)
        ]
        return torch.cat(transformed, -1).permute(0, 3, 1, 2)

    def _kspace_to_image(self, kspace, like):
        """Inverse-FFT every consecutive channel pair (in float32) and cast back to the tensor type of ``like``."""
        channels_last = kspace.permute(0, 2, 3, 1).contiguous()
        transformed = [
            ifft2c(pair.float(), fft_type=self.fft_type, fft_dim=self._spatial_dims).type(like.type())
            for pair in torch.split(channels_last, 2, -1)
        ]
        return torch.cat(transformed, -1).permute(0, 3, 1, 2)

    def forward(self, image):
        """
        Forward method for the MultiDomainConv2d class.

        Parameters
        ----------
        image: 4D tensor in ``nn.Conv2d`` layout (batch, channels, height, width).

        Returns
        -------
        Image-branch features concatenated with the k-space-branch features along dim 1.
        """
        kspace_features = self.kspace_conv(self._image_to_kspace(image))
        kspace_branch = self._kspace_to_image(kspace_features, like=image)
        image_branch = self.image_conv(image)
        return torch.cat([image_branch, kspace_branch], dim=self._channels_dim)
class MultiDomainConvTranspose2d(nn.Module):
    """
    Multi-domain transposed convolution layer.

    Same structure as the multi-domain convolution, but both branches use ``nn.ConvTranspose2d``. One branch
    operates on the image; the other operates on its 2D FFT and is brought back with the inverse FFT. The two
    results are concatenated along the channel dimension (image branch first).
    """

    def __init__(
        self,
        fft_type,
        in_channels,
        out_channels,
        **kwargs,
    ):
        """
        Parameters
        ----------
        fft_type: FFT type, forwarded to ``fft2c`` / ``ifft2c``.
        in_channels: Number of input channels, consumed by the FFT in consecutive pairs (presumably must be even —
            confirm against ``fft2c``).
        out_channels: Total number of output channels; each branch produces ``out_channels // 2``.
        **kwargs: Forwarded unchanged to both ``nn.ConvTranspose2d`` constructors (e.g. ``kernel_size``, ``stride``).
        """
        super().__init__()
        per_branch = out_channels // 2
        self.image_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=per_branch, **kwargs)
        self.kspace_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=per_branch, **kwargs)
        self.fft_type = fft_type
        self._channels_dim = 1
        # Spatial axes once the tensor is permuted to channels-last.
        self._spatial_dims = [1, 2]

    def forward(self, image):
        """
        Forward method for the MultiDomainConvTranspose2d class.

        Parameters
        ----------
        image: 4D tensor in ``nn.ConvTranspose2d`` layout (batch, channels, height, width).

        Returns
        -------
        Image-branch features concatenated with the k-space-branch features along dim 1.
        """
        fft_kwargs = {"fft_type": self.fft_type, "fft_dim": self._spatial_dims}

        image_pairs = torch.split(image.permute(0, 2, 3, 1).contiguous(), 2, -1)
        kspace = torch.cat([fft2c(pair, **fft_kwargs) for pair in image_pairs], -1)
        kspace_features = self.kspace_conv(kspace.permute(0, 3, 1, 2))

        # The inverse FFT runs in float32; the result is cast back to the input's tensor type.
        kspace_pairs = torch.split(kspace_features.permute(0, 2, 3, 1).contiguous(), 2, -1)
        kspace_branch = torch.cat(
            [ifft2c(pair.float(), **fft_kwargs).type(image.type()) for pair in kspace_pairs], -1
        ).permute(0, 3, 1, 2)

        image_branch = self.image_conv(image)
        return torch.cat([image_branch, kspace_branch], dim=self._channels_dim)
class MultiDomainConvBlock(nn.Module):
    """
    Multi-domain convolutional block.

    Two stages, each made of a ``MultiDomainConv2d`` (3x3 kernel, padding 1, no bias), ``InstanceNorm2d``,
    in-place ``LeakyReLU`` (negative slope 0.2) and ``Dropout2d``.
    """

    def __init__(self, fft_type, in_channels: int, out_channels: int, dropout_probability: float):
        """
        Parameters
        ----------
        fft_type: FFT type forwarded to the multi-domain convolutions.
        in_channels: Number of input channels of the first stage.
        out_channels: Number of output channels of both stages.
        dropout_probability: Dropout probability used after each stage.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dropout_probability = dropout_probability

        stages = []
        # The first stage maps in_channels -> out_channels, the second keeps out_channels.
        for stage_in_channels in (in_channels, out_channels):
            stages.extend(
                [
                    MultiDomainConv2d(
                        fft_type, stage_in_channels, out_channels, kernel_size=3, padding=1, bias=False
                    ),
                    nn.InstanceNorm2d(out_channels),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    nn.Dropout2d(dropout_probability),
                ]
            )
        self.layers = nn.Sequential(*stages)

    def forward(self, _input: torch.Tensor):
        """Run ``_input`` through both conv -> norm -> activation -> dropout stages."""
        return self.layers(_input)

    def __repr__(self):
        head = f"MultiDomainConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, "
        tail = f"dropout_probability={self.dropout_probability})"
        return head + tail
class TransposeMultiDomainConvBlock(nn.Module):
    """
    Transpose multi-domain convolutional block.

    One ``MultiDomainConvTranspose2d`` layer (kernel size 2, stride 2, no bias) followed by instance normalization
    and in-place LeakyReLU activation (negative slope 0.2).
    """

    def __init__(self, fft_type, in_channels: int, out_channels: int):
        """
        Parameters
        ----------
        fft_type: FFT type forwarded to the multi-domain transposed convolution.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = nn.Sequential(
            MultiDomainConvTranspose2d(fft_type, in_channels, out_channels, kernel_size=2, stride=2, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, input_data: torch.Tensor):
        """Forward method for the TransposeMultiDomainConvBlock class."""
        return self.layers(input_data)

    def __repr__(self):
        # Report this class's own name (the copy-pasted "MultiDomainConvBlock" label was misleading in model dumps).
        return f"TransposeMultiDomainConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels})"
class StandardizationLayer(nn.Module):
    r"""
    Multi-channel data standardization method. Inspired by the AIRS Medical submission to the fastMRI 2020
    challenge.

    Given individual coil images :math:`\{x_i\}_{i=1}^{N_c}` and coil sensitivity maps :math:`\{S_i\}_{i=1}^{N_c}`,
    the reference formulation returns

    .. math::
        [(x_{sense}, {x_{res}}_1), ..., (x_{sense}, {x_{res}}_{N_c})]

    where :math:`{x_{res}}_i = x_i - S_i \times x_{sense}` and
    :math:`x_{sense} = \sum_{i=1}^{N_c} {S_i}^{*} \times x_i`.

    NOTE(review): the implementation below computes the residual as :math:`x_{sense} - S_i \times x_{sense}`,
    i.e. it subtracts from the combined image rather than from :math:`x_i`. See the comment in ``forward``.
    """

    def __init__(self, coil_dim=1, channel_dim=-1):
        """
        Parameters
        ----------
        coil_dim: Dimension indexing the coils.
        channel_dim: Dimension along which the combined image and each residual are concatenated.
        """
        super().__init__()
        self.coil_dim = coil_dim
        self.channel_dim = channel_dim

    def forward(self, coil_images: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        coil_images: Coil images with coils along ``coil_dim``. The real/imaginary layout is presumably the one
            expected by ``complex_mul`` — confirm.
        sensitivity_map: Coil sensitivity maps in the same layout as ``coil_images``.

        Returns
        -------
        For each coil, the combined image concatenated with that coil's residual along ``channel_dim``, with the
        per-coil results stacked along ``coil_dim``.
        """
        # x_sense = sum over coils of conj(S_i) * x_i
        combined_image = complex_mul(coil_images, complex_conj(sensitivity_map)).sum(self.coil_dim)
        expanded_combined = combined_image.unsqueeze(self.coil_dim)
        # NOTE(review): the class docstring's formula is x_i - S_i * x_sense, but this subtracts from x_sense
        # instead of coil_images. It is kept as-is (presumably mirroring the upstream DIRECT code) because changing
        # it alters the features the network is trained on; confirm which formulation is intended.
        residual_image = expanded_combined - complex_mul(expanded_combined, sensitivity_map)
        per_coil = [
            torch.cat(
                [combined_image, residual_image.select(self.coil_dim, idx)],
                self.channel_dim,
            ).unsqueeze(self.coil_dim)
            for idx in range(coil_images.size(self.coil_dim))
        ]
        return torch.cat(per_coil, self.coil_dim)
class MultiDomainUnet2d(nn.Module):
    """
    U-Net variant built from multi-domain convolution blocks, used by the multi-domain network as in the AIRS
    Medical submission to the fastMRI 2020 challenge.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_filters: int,
        num_pool_layers: int,
        dropout_probability: float,
        fft_type: str = "orthogonal",
    ):
        """
        Parameters
        ----------
        in_channels: Number of input channels to the u-net.
        out_channels: Number of output channels of the u-net (set by the final 1x1 convolution).
        num_filters: Number of output channels of the first convolutional block.
        num_pool_layers: Number of down-sampling and up-sampling stages (depth). At least one stage is always built.
        dropout_probability: Dropout probability.
        fft_type: FFT type forwarded to every multi-domain block.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_filters = num_filters
        self.num_pool_layers = num_pool_layers
        self.dropout_probability = dropout_probability
        self.fft_type = fft_type

        # Encoder: in_channels -> num_filters, then every further block doubles the width.
        encoder = [MultiDomainConvBlock(fft_type, in_channels, num_filters, dropout_probability)]
        width = num_filters
        for _ in range(num_pool_layers - 1):
            encoder.append(MultiDomainConvBlock(fft_type, width, 2 * width, dropout_probability))
            width *= 2
        self.down_sample_layers = nn.ModuleList(encoder)

        # Bottleneck.
        self.conv = MultiDomainConvBlock(fft_type, width, 2 * width, dropout_probability)

        # Decoder: each stage upsamples (halving the width) and then fuses the matching skip connection.
        # Creation order (transpose block before fusion block) fixes the order of parameter-initialisation RNG draws.
        transposes = []
        fusions = []
        for _ in range(num_pool_layers - 1):
            transposes.append(TransposeMultiDomainConvBlock(fft_type, 2 * width, width))
            fusions.append(MultiDomainConvBlock(fft_type, 2 * width, width, dropout_probability))
            width //= 2
        transposes.append(TransposeMultiDomainConvBlock(fft_type, 2 * width, width))
        fusions.append(
            nn.Sequential(
                MultiDomainConvBlock(fft_type, 2 * width, width, dropout_probability),
                nn.Conv2d(width, self.out_channels, kernel_size=1, stride=1),
            )
        )
        self.up_conv = nn.ModuleList(fusions)
        self.up_transpose_conv = nn.ModuleList(transposes)

    def forward(self, input_data: torch.Tensor):
        """
        Forward pass of the u-net.

        Parameters
        ----------
        input_data: 4D tensor (batch, channels, height, width).

        Returns
        -------
        Output of the final 1x1 convolution.
        """
        skips = []
        features = input_data

        for encoder_block in self.down_sample_layers:
            features = encoder_block(features)
            skips.append(features)
            features = F.avg_pool2d(features, kernel_size=2, stride=2, padding=0)

        features = self.conv(features)

        for upsample, fuse in zip(self.up_transpose_conv, self.up_conv):
            skip = skips.pop()
            features = upsample(features)
            # Odd input sizes lose a row/column to pooling; reflect-pad one on the right/bottom to realign.
            pad_right = 1 if features.shape[-1] != skip.shape[-1] else 0
            pad_bottom = 1 if features.shape[-2] != skip.shape[-2] else 0
            if pad_right or pad_bottom:
                features = F.pad(features, [0, pad_right, 0, pad_bottom], "reflect")
            features = fuse(torch.cat([features, skip], dim=1))

        return features