# coding=utf-8
__author__ = "Dimitrios Karkalousos"
# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/recurrentvarnet/recurrentvarnet.py
# Copyright (c) DIRECT Contributors
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mridc.collections.common.parts.fft import fft2c, ifft2c
from mridc.collections.common.parts.utils import complex_conj, complex_mul
from mridc.collections.reconstruction.models.recurrentvarnet.conv2gru import Conv2dGRU
class RecurrentInit(nn.Module):
    """
    Recurrent State Initializer (RSI) module of the Recurrent Variational Network, as presented in Yiasemis, George,
    et al. The RSI module learns to initialize the recurrent hidden state :math:`h_0`, the input of the first
    RecurrentVarNetBlock of the RecurrentVarNet.

    References
    ----------
    ..

        Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to
        the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org,
        http://arxiv.org/abs/2111.09639.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        channels: Tuple[int, ...],
        dilations: Tuple[int, ...],
        depth: int = 2,
        multiscale_depth: int = 1,
    ):
        """
        Inits RecurrentInit.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of hidden channels of the recurrent unit of the RecurrentVarNet Block.
        channels : Tuple[int, ...]
            Channels :math:`n_d` of the convolutional layers of the initializer.
        dilations : Tuple[int, ...]
            Dilations :math:`p` of the convolutional layers of the initializer.
        depth : int
            Number of layers :math:`n_l` of the RecurrentVarNet Block.
        multiscale_depth : int
            Number of feature layers to aggregate for the output; if 1, multi-scale context aggregation is disabled.
        """
        super().__init__()

        self.conv_blocks = nn.ModuleList()
        self.out_blocks = nn.ModuleList()
        self.depth = depth
        self.multiscale_depth = multiscale_depth

        tch = in_channels
        for curr_channels, curr_dilations in zip(channels, dilations):
            block = [
                nn.ReplicationPad2d(curr_dilations),
                nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations),
            ]
            tch = curr_channels
            self.conv_blocks.append(nn.Sequential(*block))

        tch = np.sum(channels[-multiscale_depth:])
        for _ in range(depth):
            block = [nn.Conv2d(tch, out_channels, 1, padding=0)]
            self.out_blocks.append(nn.Sequential(*block))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the initialization for the recurrent unit given input `x`.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor from which the initialization is computed.

        Returns
        -------
        torch.Tensor
            Initial recurrent hidden state computed from input `x`.
        """
        features = []
        for block in self.conv_blocks:
            x = F.relu(block(x), inplace=True)
            if self.multiscale_depth > 1:
                features.append(x)

        if self.multiscale_depth > 1:
            x = torch.cat(features[-self.multiscale_depth :], dim=1)

        output_list = []
        for block in self.out_blocks:
            y = F.relu(block(x), inplace=True)
            output_list.append(y)

        return torch.stack(output_list, dim=-1)
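

# Illustrative usage sketch (not part of the original module): how the initializer might be instantiated
# and what shape its output takes. The concrete channel/dilation values and the input shape below are
# assumptions chosen for illustration only.
#
#   rsi = RecurrentInit(
#       in_channels=2,                  # real and imaginary parts stacked channels-first
#       out_channels=64,                # hidden channels of the Conv2dGRU recurrent unit
#       channels=(32, 32, 64, 64),      # n_d per convolutional layer (assumed values)
#       dilations=(1, 1, 2, 4),         # dilation p per convolutional layer (assumed values)
#       depth=2,                        # one 1x1 output head per recurrent layer n_l
#       multiscale_depth=3,             # aggregate the last 3 feature maps
#   )
#   x = torch.rand(1, 2, 218, 170)      # e.g. a zero-filled SENSE image, [batch, 2, height, width]
#   h0 = rsi(x)                         # [1, 64, 218, 170, 2] -> [batch, out_channels, height, width, depth]
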

class RecurrentVarNetBlock(nn.Module):
    r"""
    Recurrent Variational Network Block :math:`\mathcal{H}_{\theta_{t}}` as presented in Yiasemis, George, et al.

    References
    ----------
    ..

        Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to
        the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org,
        http://arxiv.org/abs/2111.09639.
    """

    def __init__(
        self,
        in_channels: int = 2,
        hidden_channels: int = 64,
        num_layers: int = 4,
        fft_type: str = "orthogonal",
    ):
        """
        Inits RecurrentVarNetBlock.

        Parameters
        ----------
        in_channels : int
            Number of input channels. Default is 2 for complex data.
        hidden_channels : int
            Number of hidden channels. Default is 64.
        num_layers : int
            Number of layers :math:`n_l` of the recurrent unit. Default is 4.
        fft_type : str
            FFT type. Default is "orthogonal".
        """
        super().__init__()

        self.fft_type = fft_type

        self.learning_rate = nn.Parameter(torch.tensor([1.0]))  # :math:`\alpha_t`
        self.regularizer = Conv2dGRU(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            num_layers=num_layers,
            replication_padding=True,
        )  # Recurrent Unit of RecurrentVarNet Block :math:`\mathcal{H}_{\theta_t}`

    def forward(
        self,
        current_kspace: torch.Tensor,
        masked_kspace: torch.Tensor,
        sampling_mask: torch.Tensor,
        sensitivity_map: torch.Tensor,
        hidden_state: Union[None, torch.Tensor],
        coil_dim: int = 1,
        complex_dim: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the forward pass of RecurrentVarNetBlock.

        Parameters
        ----------
        current_kspace : torch.Tensor, shape [batch_size, n_coil, height, width, 2]
            Current k-space prediction.
        masked_kspace : torch.Tensor, shape [batch_size, n_coil, height, width, 2]
            Subsampled k-space.
        sampling_mask : torch.Tensor, shape [batch_size, 1, height, width, 1]
            Sampling mask.
        sensitivity_map : torch.Tensor, shape [batch_size, n_coil, height, width, 2]
            Coil sensitivity maps.
        hidden_state : None or torch.Tensor, shape [batch_size, n_l, height, width, hidden_channels]
            ConvGRU hidden state.
        coil_dim : int
            Coil dimension. Default is 1.
        complex_dim : int
            Complex dimension. Default is -1.

        Returns
        -------
        new_kspace : torch.Tensor, shape [batch_size, n_coil, height, width, 2]
            New k-space prediction.
        hidden_state : list of torch.Tensor, shape [batch_size, hidden_channels, height, width, num_layers]
            Next hidden state.
        """
        kspace_error = torch.where(
            sampling_mask == 0,
            torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
            current_kspace - masked_kspace,
        )

        recurrent_term = torch.cat(
            [
                complex_mul(ifft2c(kspace, fft_type=self.fft_type), complex_conj(sensitivity_map)).sum(coil_dim)
                for kspace in torch.split(current_kspace, 2, complex_dim)
            ],
            dim=complex_dim,
        ).permute(0, 3, 1, 2)

        recurrent_term, hidden_state = self.regularizer(recurrent_term, hidden_state)  # :math:`w_t`, :math:`h_{t+1}`
        recurrent_term = recurrent_term.permute(0, 2, 3, 1)

        recurrent_term = torch.cat(
            [
                fft2c(complex_mul(image.unsqueeze(coil_dim), sensitivity_map), fft_type=self.fft_type)
                for image in torch.split(recurrent_term, 2, complex_dim)
            ],
            dim=complex_dim,
        )

        new_kspace = current_kspace - self.learning_rate * kspace_error + recurrent_term

        return new_kspace, hidden_state
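

# Illustrative sketch (not part of the original file): how RecurrentInit and RecurrentVarNetBlock are
# typically composed into the unrolled RecurrentVarNet scheme. The names `initializer`, `num_steps`,
# `image_init`, and the call pattern below are assumptions for illustration; the actual model class
# wires these modules together itself.
#
#   blocks = nn.ModuleList([RecurrentVarNetBlock(hidden_channels=64) for _ in range(num_steps)])
#   hidden_state = initializer(image_init)  # h_0 from the RSI module, reshaped to the hidden-state layout
#   kspace_prediction = masked_kspace.clone()
#   for block in blocks:
#       kspace_prediction, hidden_state = block(
#           kspace_prediction, masked_kspace, sampling_mask, sensitivity_map, hidden_state
#       )
#   # Each step applies k_{t+1} = k_t - alpha_t * M * (k_t - y) + refinement, i.e. a soft data-consistency
#   # update on the sampled locations plus the learned refinement produced by the Conv2dGRU regularizer.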