mridc.collections.reconstruction.models.recurrentvarnet package

Submodules

mridc.collections.reconstruction.models.recurrentvarnet.conv2gru module

class mridc.collections.reconstruction.models.recurrentvarnet.conv2gru.Conv2dGRU(in_channels: int, hidden_channels: int, out_channels: Optional[int] = None, num_layers: int = 2, gru_kernel_size=1, orthogonal_initialization: bool = True, instance_norm: bool = False, dense_connect: int = 0, replication_padding: bool = True)

Bases: torch.nn.modules.module.Module

2D Convolutional GRU Network.

forward(cell_input: torch.Tensor, previous_state: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Computes the Conv2dGRU forward pass given the tensors cell_input and previous_state.

Parameters
  • cell_input (torch.Tensor) – Reconstruction input.

  • previous_state (torch.Tensor) – Tensor of previous states.

Returns

Output and new states.

Return type

Tuple[torch.Tensor, torch.Tensor]

training: bool
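To illustrate the kind of update a convolutional GRU performs, and the hidden-state layout [batch_size, hidden_channels, height, width, num_layers] documented for RecurrentVarNetBlock.forward, here is a minimal sketch. TinyConvGRU is a hypothetical class written for illustration only: it is not the mridc Conv2dGRU and omits orthogonal initialization, instance normalization, dense connections and replication padding. It assumes a standard GRU gating scheme (reset gate, update gate, tanh candidate) computed with 1x1 convolutions.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyConvGRU(nn.Module):
    """Hypothetical stacked ConvGRU; state is stored as [N, hidden_channels, H, W, num_layers]."""

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int = 2):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.input_convs = nn.ModuleList()
        self.reset_gates = nn.ModuleList()
        self.update_gates = nn.ModuleList()
        self.candidate_convs = nn.ModuleList()
        for idx in range(num_layers):
            c_in = in_channels if idx == 0 else hidden_channels
            self.input_convs.append(nn.Conv2d(c_in, hidden_channels, kernel_size=3, padding=1))
            # Each gate sees the layer input concatenated with that layer's previous state.
            self.reset_gates.append(nn.Conv2d(2 * hidden_channels, hidden_channels, kernel_size=1))
            self.update_gates.append(nn.Conv2d(2 * hidden_channels, hidden_channels, kernel_size=1))
            self.candidate_convs.append(nn.Conv2d(2 * hidden_channels, hidden_channels, kernel_size=1))
        self.output_conv = nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        if state is None:
            # Zero initial state when no learned initializer is used.
            n, _, h, w = x.shape
            state = x.new_zeros((n, self.hidden_channels, h, w, self.num_layers))
        new_states = []
        for idx in range(self.num_layers):
            x = F.relu(self.input_convs[idx](x))
            h_prev = state[..., idx]  # previous state of this layer: [N, C_h, H, W]
            stacked = torch.cat([x, h_prev], dim=1)
            r = torch.sigmoid(self.reset_gates[idx](stacked))   # reset gate
            z = torch.sigmoid(self.update_gates[idx](stacked))  # update gate
            h_cand = torch.tanh(self.candidate_convs[idx](torch.cat([x, r * h_prev], dim=1)))
            h_new = (1 - z) * h_prev + z * h_cand
            new_states.append(h_new)
            x = h_new  # the updated state is the input to the next layer
        return self.output_conv(x), torch.stack(new_states, dim=-1)
```

In an unrolled scheme, the returned state would be passed back in as `state` on the next call, e.g. `out, state = gru(x, None)` followed by `out, state = gru(x, state)`.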

mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet module

class mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet.RecurrentInit(in_channels: int, out_channels: int, channels: Tuple[int, ...], dilations: Tuple[int, ...], depth: int = 2, multiscale_depth: int = 1)

Bases: torch.nn.modules.module.Module

Recurrent State Initializer (RSI) module of the Recurrent Variational Network, as presented in Yiasemis, George, et al. The RSI module learns to initialize the recurrent hidden state \(h_0\), which is the input to the first RecurrentVarNetBlock of the RecurrentVarNet.

References

Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.

forward(x: torch.Tensor) → torch.Tensor

Computes the initialization for the recurrent unit given input x.

Parameters

x (torch.Tensor) – Input to RecurrentInit, from which the initial hidden state is computed.

Returns

Initial recurrent hidden state computed from input x.

Return type

torch.Tensor

training: bool
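A rough sketch of what an RSI-style initializer can look like follows. TinyRecurrentInit is hypothetical and only assumes that channels and dilations describe a stack of dilated 3x3 convolutions and that depth gives the number of initial states produced (one per recurrent layer); multiscale_depth is omitted entirely. Consult the mridc source for the actual architecture.

```python
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyRecurrentInit(nn.Module):
    """Hypothetical RSI-style initializer: dilated 3x3 convs, then one 1x1 head per recurrent layer."""

    def __init__(self, in_channels: int, out_channels: int, channels: Sequence[int],
                 dilations: Sequence[int], depth: int = 2):
        super().__init__()
        layers = []
        c_in = in_channels
        for c_out, dilation in zip(channels, dilations):
            # Padding by `dilation` on each side keeps H and W unchanged for a dilated 3x3 kernel.
            layers += [
                nn.ReplicationPad2d(dilation),
                nn.Conv2d(c_in, c_out, kernel_size=3, dilation=dilation),
                nn.ReLU(),
            ]
            c_in = c_out
        self.features = nn.Sequential(*layers)
        self.heads = nn.ModuleList([nn.Conv2d(c_in, out_channels, kernel_size=1) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.features(x)
        # One initial state per recurrent layer, stacked on a trailing "layer" axis:
        # [N, out_channels, H, W, depth].
        return torch.stack([F.relu(head(feats)) for head in self.heads], dim=-1)
```

Stacking the per-layer states on the last axis yields the same [batch, hidden_channels, height, width, num_layers] layout a stacked ConvGRU consumes, assuming out_channels equals the GRU's hidden_channels and depth equals its num_layers.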
class mridc.collections.reconstruction.models.recurrentvarnet.recurentvarnet.RecurrentVarNetBlock(in_channels: int = 2, hidden_channels: int = 64, num_layers: int = 4, fft_type: str = 'orthogonal')

Bases: torch.nn.modules.module.Module

Recurrent Variational Network Block \(\mathcal{H}_{\theta_{t}}\), as presented in Yiasemis, George, et al.

References

Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.

forward(current_kspace: torch.Tensor, masked_kspace: torch.Tensor, sampling_mask: torch.Tensor, sensitivity_map: torch.Tensor, hidden_state: Union[None, torch.Tensor], coil_dim: int = 1, complex_dim: int = -1) → Tuple[torch.Tensor, torch.Tensor]

Computes the forward pass of RecurrentVarNetBlock.

Parameters
  • current_kspace (torch.Tensor) – Current k-space prediction, of shape [batch_size, n_coil, height, width, 2].

  • masked_kspace (torch.Tensor) – Subsampled k-space, of shape [batch_size, n_coil, height, width, 2].

  • sampling_mask (torch.Tensor) – Sampling mask, of shape [batch_size, 1, height, width, 1].

  • sensitivity_map (torch.Tensor) – Coil sensitivities, of shape [batch_size, n_coil, height, width, 2].

  • hidden_state (None or torch.Tensor) – ConvGRU hidden state, of shape [batch_size, hidden_channels, height, width, num_layers], or None.

  • coil_dim (int) – Coil dimension. Default: 1.

  • complex_dim (int) – Complex dimension. Default: -1.

Returns

  • new_kspace (torch.Tensor) – New k-space prediction, of shape [batch_size, n_coil, height, width, 2].

  • hidden_state (torch.Tensor) – Next hidden state, of shape [batch_size, hidden_channels, height, width, num_layers].
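The tensor conventions in the parameter list (coils on coil_dim=1, real and imaginary parts in a trailing axis of size 2 at complex_dim=-1) can be illustrated with a coil-combination step, \(x = \sum_c \overline{S_c}\, \mathcal{F}^{-1}(k_c)\). The helper coil_combine below is hypothetical. It assumes an orthonormal 2D inverse FFT over (height, width) with no fftshift; how mridc's fft_type option centers the transform is not shown here.

```python
import torch


def coil_combine(kspace: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sum_c conj(S_c) * IFFT(k_c).

    kspace, sensitivity_map: real tensors of shape [N, coil, H, W, 2] (coil_dim=1, complex_dim=-1).
    Returns a real tensor of shape [N, H, W, 2].
    """
    # Reinterpret the trailing size-2 axis as complex numbers: [N, coil, H, W] complex.
    k = torch.view_as_complex(kspace.contiguous())
    s = torch.view_as_complex(sensitivity_map.contiguous())
    # Orthonormal inverse FFT over the two spatial axes (assumption: no fftshift).
    coil_images = torch.fft.ifft2(k, dim=(-2, -1), norm="ortho")
    # Weight each coil image by the conjugate sensitivity and sum over the coil axis.
    image = (s.conj() * coil_images).sum(dim=1)
    return torch.view_as_real(image)
```

If the sensitivities satisfy \(\sum_c |S_c|^2 = 1\) at every pixel, combining fully sampled k-space of \(S_c x\) recovers \(x\) exactly, which is a convenient sanity check for the layout.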

training: bool
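Variational-network-style blocks typically combine a gradient step on the data-consistency term with a learned correction. The sketch below assumes an update of the form \(k^{t+1} = k^{t} - \alpha\, M \odot (k^{t} - \tilde{k}) + w\), where \(M\) is sampling_mask, \(\tilde{k}\) is masked_kspace and \(w\) stands in for the regularizer output (in RecurrentVarNetBlock this term comes from the ConvGRU). The function kspace_update_step and its learning_rate and correction arguments are hypothetical and are not the mridc API.

```python
import torch


def kspace_update_step(current_kspace: torch.Tensor, masked_kspace: torch.Tensor,
                       sampling_mask: torch.Tensor, learning_rate: float,
                       correction: torch.Tensor) -> torch.Tensor:
    """Hypothetical single update: data-consistency gradient step plus a learned correction.

    k-space tensors are [N, coil, H, W, 2]; the mask is [N, 1, H, W, 1] and broadcasts
    over the coil axis and the real/imaginary axis.
    """
    # Residual only where k-space was actually acquired (mask == 1).
    dc_residual = sampling_mask * (current_kspace - masked_kspace)
    return current_kspace - learning_rate * dc_residual + correction
```

With learning_rate = 1 and a zero correction, acquired samples are replaced by the measurements and unacquired samples are left unchanged; a learned step size and correction let the network trade these off.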

Module contents