Source code for mridc.collections.reconstruction.models.crossdomain.multicoil

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/crossdomain/multicoil.py
# Copyright (c) DIRECT Contributors

import torch
import torch.nn as nn


[docs]class MultiCoil(nn.Module): """ This makes the forward pass of multi-coil data of shape (N, N_coils, H, W, C) to a model. If coil_to_batch is set to True, coil dimension is moved to the batch dimension. Otherwise, it passes to the model each coil-data individually. """ def __init__(self, model: nn.Module, coil_dim: int = 1, coil_to_batch: bool = False): """Inits MultiCoil. Parameters ---------- model: Any nn.Module that takes as input with 4D data (N, H, W, C). Typically, a convolutional-like model. torch.nn.Module coil_dim: Coil dimension. int, Default: 1. coil_to_batch: If True batch and coil dimensions are merged when forwarded by the model and unmerged when outputted. Otherwise, input is forwarded to the model per coil. bool, Default: False. """ super().__init__() self.model = model self.coil_to_batch = coil_to_batch self._coil_dim = coil_dim def _compute_model_per_coil(self, data: torch.Tensor) -> torch.Tensor: """Computes the model per coil.""" output = [] for idx in range(data.size(self._coil_dim)): subselected_data = data.select(self._coil_dim, idx) if subselected_data.shape[-1] == 2 and subselected_data.dim() == 4: output.append(self.model(subselected_data.permute(0, 3, 1, 2))) else: output.append(self.model(subselected_data.unsqueeze(1)).squeeze(1)) output = torch.stack(output, dim=self._coil_dim) return output
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Performs the forward pass of MultiCoil. Parameters ---------- x: Multi-coil input. torch.Tensor, shape (N, N_coils, H, W, C) Returns ------- Multi-coil output. torch.Tensor, shape (N, N_coils, H, W, C) """ if self.coil_to_batch: x = x.clone() batch, coil, channels, height, width = x.size() x = x.reshape(batch * coil, channels, height, width).contiguous() x = self.model(x).permute(0, 2, 3, 1) x = x.reshape(batch, coil, height, width, -1).permute(0, 1, 4, 2, 3) else: x = self._compute_model_per_coil(x).contiguous() return x