Source code for mridc.collections.reconstruction.models.rim.rim_block

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

from typing import Any, Tuple, Union

import torch

from mridc.collections.common.parts.fft import fft2c, ifft2c
from mridc.collections.common.parts.utils import complex_conj, complex_mul
from mridc.collections.reconstruction.models.rim.conv_layers import ConvNonlinear, ConvRNNStack
from mridc.collections.reconstruction.models.rim.rnn_cells import ConvGRUCell, ConvMGUCell, IndRNNCell
from mridc.collections.reconstruction.models.rim.utils import log_likelihood_gradient


class RIMBlock(torch.nn.Module):
    """
    One Recurrent Inference Machine (RIM) block.

    The block stacks convolution + recurrent-cell pairs, followed by a final convolution, and unrolls them for a
    fixed number of time steps. At every step the log-likelihood gradient of the current estimate is pushed through
    the stack and the result is added to the estimate.
    """

    def __init__(
        self,
        recurrent_layer=None,
        conv_filters=None,
        conv_kernels=None,
        conv_dilations=None,
        conv_bias=None,
        recurrent_filters=None,
        recurrent_kernels=None,
        recurrent_dilations=None,
        recurrent_bias=None,
        depth: int = 2,
        time_steps: int = 8,
        conv_dim: int = 2,
        no_dc: bool = False,
        fft_type: str = "orthogonal",
    ):
        """
        Build the convolutional/recurrent stack of the block.

        Parameters
        ----------
        recurrent_layer: Name of the recurrent cell, matched case-insensitively against "GRU", "MGU" and "INDRNN".
        conv_filters: Output channels of each convolutional layer. A value of 0 skips that convolution.
        conv_kernels: Kernel size of each convolutional layer.
        conv_dilations: Dilation of each convolutional layer.
        conv_bias: Bias setting of each convolutional layer.
        recurrent_filters: Output channels of each recurrent cell. A value of 0 skips that stage's recurrent cell.
        recurrent_kernels: Kernel size of each recurrent cell.
        recurrent_dilations: Dilation of each recurrent cell.
        recurrent_bias: Bias setting of each recurrent cell.
        depth: The number of input channels of the first layer is set to ``depth * 2``.
        time_steps: Number of unrolled iterations performed by ``forward``.
        conv_dim: Dimensionality passed to the convolutional layers.
        no_dc: If True, no learnable data-consistency weight is created and ``forward`` returns the intermediate
            estimates directly, without the soft data-consistency step.
        fft_type: FFT type forwarded to the (i)FFT and log-likelihood-gradient helpers.

        Raises
        ------
        ValueError
            If ``recurrent_layer`` is not one of the supported cell names.
        """
        super().__init__()

        self.input_size = depth * 2
        self.time_steps = time_steps

        self.layers = torch.nn.ModuleList()

        # Supported recurrent cells, keyed by their upper-cased name.
        cell_types = {"GRU": ConvGRUCell, "MGU": ConvMGUCell, "INDRNN": IndRNNCell}

        # At most three stages are configured: the first two pair a ReLU convolution with a recurrent cell, the
        # third one has neither a nonlinearity nor a recurrent cell and ends up as the final layer.
        conv_specs = zip(conv_filters, conv_kernels, conv_dilations, conv_bias, ["relu", "relu", None])
        rnn_specs = zip(
            recurrent_filters,
            recurrent_kernels,
            recurrent_dilations,
            recurrent_bias,
            [recurrent_layer, recurrent_layer, None],
        )

        for conv_spec, rnn_spec in zip(conv_specs, rnn_specs):
            conv_features, conv_k_size, conv_dilation, l_conv_bias, nonlinear = conv_spec
            rnn_features, rnn_k_size, rnn_dilation, rnn_bias, rnn_name = rnn_spec

            conv_layer = None
            if conv_features != 0:
                conv_layer = ConvNonlinear(
                    self.input_size,
                    conv_features,
                    conv_dim=conv_dim,
                    kernel_size=conv_k_size,
                    dilation=conv_dilation,
                    bias=l_conv_bias,
                    nonlinear=nonlinear,
                )
                self.input_size = conv_features

            # Stages without a recurrent cell are not added to the stack.
            if rnn_features == 0 or rnn_name is None:
                continue

            cell_cls = cell_types.get(rnn_name.upper())
            if cell_cls is None:
                raise ValueError("Please specify a proper recurrent layer type.")

            # NOTE(review): the recurrent cell is always built with conv_dim=2, independent of the ``conv_dim``
            # argument used for the convolutions above - confirm this is intended.
            rnn_layer = cell_cls(
                self.input_size,
                rnn_features,
                conv_dim=2,
                kernel_size=rnn_k_size,
                dilation=rnn_dilation,
                bias=rnn_bias,
            )
            self.input_size = rnn_features

            self.layers.append(ConvRNNStack(conv_layer, rnn_layer))

        # The convolution built in the last iterated stage maps the features back to an update of the estimate.
        self.final_layer = torch.nn.Sequential(conv_layer)

        self.recurrent_filters = recurrent_filters
        self.fft_type = fft_type
        self.no_dc = no_dc

        if not self.no_dc:
            # Learnable scale of the soft data-consistency term, initialised to one.
            self.dc_weight = torch.nn.Parameter(torch.ones(1))
            # Plain tensor (not a registered buffer); it is moved to the input's device/dtype in ``forward``.
            self.zero = torch.zeros(1, 1, 1, 1, 1)

    def forward(
        self,
        pred: torch.Tensor,
        masked_kspace: torch.Tensor,
        sense: torch.Tensor,
        mask: torch.Tensor,
        eta: torch.Tensor = None,
        hx: torch.Tensor = None,
        sigma: float = 1.0,
        keep_eta: bool = False,
    ) -> Tuple[Any, Union[list, torch.Tensor, None]]:
        """
        Unroll the block for ``self.time_steps`` iterations.

        Parameters
        ----------
        pred: Predicted k-space. If a list is given, its last element is used, detached from the graph.
        masked_kspace: Subsampled k-space.
        sense: Coil sensitivity maps.
        mask: Sampling mask, used as the condition of the soft data-consistency term.
        eta: Initial estimate. If None or with fewer than 3 dimensions, it is initialised from ``pred``.
        hx: Hidden states, one per recurrent layer. If None, zero-filled states are created. A given list is
            updated in place.
        sigma: Noise level forwarded to the log-likelihood gradient.
        keep_eta: If True, ``pred`` itself is used as the initial estimate instead of its coil-combined inverse FFT.

        Returns
        -------
        A tuple whose first element is a list with one entry per time step: the intermediate estimates if
        ``no_dc`` is set, otherwise the corresponding k-space after the soft data-consistency step. The second
        element is always None.
        """
        if hx is None:
            hx = []
            for num_filters in self.recurrent_filters:
                if num_filters != 0:
                    hx.append(
                        masked_kspace.new_zeros((masked_kspace.size(0), num_filters, *masked_kspace.size()[2:-1]))
                    )

        if isinstance(pred, list):
            pred = pred[-1].detach()

        if eta is None or eta.ndim < 3:
            if keep_eta:
                eta = pred
            else:
                # Coil-combine the inverse FFT of the prediction by summing over dim 1.
                coil_images = complex_mul(ifft2c(pred, fft_type=self.fft_type), complex_conj(sense))
                eta = torch.sum(coil_images, 1)

        estimates = []
        for _ in range(self.time_steps):
            features = log_likelihood_gradient(
                eta, masked_kspace, sense, mask, sigma=sigma, fft_type=self.fft_type
            ).contiguous()

            # Each layer consumes the previous layer's new hidden state as its input.
            for idx, layer in enumerate(self.layers):
                hx[idx] = layer(features, hx[idx])
                features = hx[idx]

            # Move the channel dim of the final layer's output to the end before adding it to the estimate.
            eta = eta + self.final_layer(features).permute(0, 2, 3, 1)
            estimates.append(eta)

        if self.no_dc:
            return estimates, None

        # Soft data consistency: weighted difference to the measured k-space where the mask is set, zero elsewhere.
        soft_dc = torch.where(mask, pred - masked_kspace, self.zero.to(masked_kspace)) * self.dc_weight

        current_kspace = []
        for estimate in estimates:
            estimate_kspace = fft2c(complex_mul(estimate.unsqueeze(1), sense), fft_type=self.fft_type)
            current_kspace.append(masked_kspace - soft_dc - estimate_kspace)

        # NOTE(review): the updated hidden states are not returned; the second value is always None.
        return current_kspace, None