# Source code for mridc.collections.reconstruction.models.crossdomain.crossdomain

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/crossdomain/crossdomain.py
# Copyright (c) DIRECT Contributors
from typing import Optional, Union

import torch
import torch.nn as nn

from mridc.collections.common.parts.fft import fft2c, ifft2c
from mridc.collections.common.parts.utils import complex_conj, complex_mul


class CrossDomainNetwork(nn.Module):
    """This performs optimisation in both, k-space ("K") and image ("I") domains according to domain_sequence."""

    def __init__(
        self,
        image_model_list: nn.Module,
        kspace_model_list: Optional[Union[nn.Module, None]] = None,
        domain_sequence: str = "KIKI",
        image_buffer_size: int = 1,
        kspace_buffer_size: int = 1,
        normalize_image: bool = False,
        fft_type: str = "orthogonal",
        **kwargs,
    ):
        """
        Inits CrossDomainNetwork.

        Parameters
        ----------
        image_model_list: Image domain model list. Must support ``len()`` and integer indexing; one model is
        consumed per "I" step.
            torch.nn.Module
        kspace_model_list: K-space domain model list. If set to None, a correction step is applied.
            torch.nn.Module, Default: None.
        domain_sequence: Domain sequence containing only "K" (k-space domain) and/or "I" (image domain).
        Surrounding whitespace is stripped.
            str, Default: "KIKI".
        image_buffer_size: Image buffer size, i.e. how many copies of the initial image are stacked on the
        complex axis.
            int, Default: 1.
        kspace_buffer_size: K-space buffer size, i.e. how many copies of the masked k-space are stacked on the
        complex axis.
            int, Default: 1.
        normalize_image: If True, input is normalized. Accepted for interface compatibility; it is not read
        anywhere in this class.
            bool, Default: False.
        fft_type: Type of FFT, forwarded to fft2c / ifft2c.
            str, Default: "orthogonal".
        kwargs: Keyword Arguments. Accepted and ignored.
            dict

        Raises
        ------
        ValueError
            If domain_sequence contains characters other than "K" and "I", if the number of models does not
            match the number of corresponding steps, or if a buffer size is smaller than 1.
        """
        super().__init__()

        self.fft_type = fft_type

        # Work on a per-step list of characters without rebinding the (str-typed) parameter.
        sequence = list(domain_sequence.strip())
        if not set(sequence).issubset({"K", "I"}):
            raise ValueError(f"Invalid domain sequence. Got {sequence}. Should only contain 'K' and 'I'.")

        if kspace_model_list is not None and len(kspace_model_list) != sequence.count("K"):
            raise ValueError("K-space domain steps do not match k-space model list length.")

        if len(image_model_list) != sequence.count("I"):
            raise ValueError("Image domain steps do not match image model list length.")

        # forward() builds the buffers with torch.cat([...] * size); an empty list would only fail there,
        # so reject it up front with a clear message.
        if image_buffer_size < 1 or kspace_buffer_size < 1:
            raise ValueError(
                "Buffer sizes must be at least 1. "
                f"Got image_buffer_size={image_buffer_size}, kspace_buffer_size={kspace_buffer_size}."
            )

        self.domain_sequence = sequence

        self.kspace_model_list = kspace_model_list
        self.kspace_buffer_size = kspace_buffer_size

        self.image_model_list = image_model_list
        self.image_buffer_size = image_buffer_size

        self._coil_dim = 1
        self._complex_dim = -1
        self._spatial_dims = (2, 3)

    def kspace_correction(
        self,
        block_idx: int,
        image_buffer: torch.Tensor,
        kspace_buffer: torch.Tensor,
        sampling_mask: torch.Tensor,
        sensitivity_map: torch.Tensor,
        masked_kspace: torch.Tensor,
    ) -> torch.Tensor:
        """
        Performs k-space correction.

        Every 2-channel slice of image_buffer is projected to k-space with the forward operator; the
        projections are concatenated with the current k-space buffer and the measured k-space along the
        complex axis and passed to the k-space model of this step. Without k-space models a plain difference
        is returned instead.

        Parameters
        ----------
        block_idx: Index of the k-space model to use.
            int
        image_buffer: Current image buffer, split in chunks of 2 along the last axis.
            torch.Tensor
        kspace_buffer: Current k-space buffer.
            torch.Tensor
        sampling_mask: Sampling mask; entries equal to 0 are zeroed in the projections.
            torch.Tensor
        sensitivity_map: Coil sensitivity maps.
            torch.Tensor
        masked_kspace: Subsampled k-space data.
            torch.Tensor

        Returns
        -------
        Updated k-space buffer.
            torch.Tensor
        """
        forward_buffer = torch.cat(
            [
                self._forward_operator(image.clone(), sampling_mask, sensitivity_map)
                for image in torch.split(image_buffer, 2, self._complex_dim)
            ],
            self._complex_dim,
        )
        kspace_buffer = torch.cat([kspace_buffer, forward_buffer, masked_kspace], self._complex_dim)

        if self.kspace_model_list is None:
            # Model-free correction: first 2 channels minus the next 2 channels of the stacked buffer.
            # NOTE(review): with kspace_buffer_size == 1 channels 2:4 are the first forward projection; for
            # larger k-space buffers they are the second k-space copy instead - confirm this is intended.
            return kspace_buffer[..., :2] - kspace_buffer[..., 2:4]

        # Move the stacked complex axis from last position to axis 2 for the model, then move it back.
        kspace_model = self.kspace_model_list[block_idx]
        return kspace_model(kspace_buffer.permute(0, 1, 4, 2, 3)).permute(0, 1, 3, 4, 2)

    def image_correction(
        self,
        block_idx: int,
        image_buffer: torch.Tensor,
        kspace_buffer: torch.Tensor,
        sampling_mask: torch.Tensor,
        sensitivity_map: torch.Tensor,
    ) -> torch.Tensor:
        """
        Performs image correction.

        Every 2-channel slice of kspace_buffer is projected to image space with the backward operator; the
        projections are concatenated with the current image buffer along the complex axis and passed to the
        image model of this step.

        Parameters
        ----------
        block_idx: Index of the image model to use.
            int
        image_buffer: Current image buffer.
            torch.Tensor
        kspace_buffer: Current k-space buffer, split in chunks of 2 along the last axis.
            torch.Tensor
        sampling_mask: Sampling mask; entries equal to 0 are zeroed before the projection.
            torch.Tensor
        sensitivity_map: Coil sensitivity maps.
            torch.Tensor

        Returns
        -------
        Updated image buffer.
            torch.Tensor
        """
        backward_buffer = torch.cat(
            [
                self._backward_operator(kspace.clone(), sampling_mask, sensitivity_map)
                for kspace in torch.split(kspace_buffer, 2, self._complex_dim)
            ],
            self._complex_dim,
        )

        # Move the stacked complex axis from last position to axis 1 for the model, then move it back.
        stacked = torch.cat([image_buffer, backward_buffer], self._complex_dim).permute(0, 3, 1, 2)
        image_model = self.image_model_list[block_idx]
        return image_model(stacked).permute(0, 2, 3, 1)

    def _forward_operator(
        self, image: torch.Tensor, sampling_mask: torch.Tensor, sensitivity_map: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward operator.

        Inserts a coil axis into the image, multiplies with the sensitivity maps, applies fft2c and sets every
        position where sampling_mask equals 0 to zero. The result is cast back to the input image's type.
        """
        # Allocate the fill value directly on the target device instead of creating it on CPU and moving it.
        zero = torch.zeros(1, dtype=image.dtype, device=image.device)
        kspace = fft2c(
            complex_mul(image.unsqueeze(self._coil_dim), sensitivity_map),
            fft_type=self.fft_type,
        ).type(image.type())
        return torch.where(sampling_mask == 0, zero, kspace)

    def _backward_operator(
        self, kspace: torch.Tensor, sampling_mask: torch.Tensor, sensitivity_map: torch.Tensor
    ) -> torch.Tensor:
        """
        Backward operator.

        Sets every position where sampling_mask equals 0 to zero, applies ifft2c on the float32 k-space,
        multiplies with the conjugate sensitivity maps and sums over the coil axis. The result is cast back to
        the input k-space's type.
        """
        # Allocate the fill value directly on the target device instead of creating it on CPU and moving it.
        zero = torch.zeros(1, dtype=kspace.dtype, device=kspace.device)
        kspace = torch.where(sampling_mask == 0, zero, kspace)
        coil_images = complex_mul(
            ifft2c(kspace.float(), fft_type=self.fft_type),
            complex_conj(sensitivity_map),
        )
        return coil_images.sum(self._coil_dim).type(kspace.type())

    def forward(
        self,
        masked_kspace: torch.Tensor,
        sensitivity_map: torch.Tensor,
        sampling_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the forward pass of CrossDomainNetwork.

        Parameters
        ----------
        masked_kspace: Subsampled k-space data.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        sensitivity_map: Sensitivity map.
            torch.Tensor, shape [batch_size, n_coil, height, width, 2]
        sampling_mask: Sampling mask.
            torch.Tensor, shape [batch_size, 1, height, width, 1]

        Returns
        -------
        Output image.
            torch.Tensor, shape [batch_size, height, width, 2]
        """
        input_image = self._backward_operator(masked_kspace, sampling_mask, sensitivity_map)

        # Both buffers start as repeated copies stacked along the complex axis. They already share the
        # device of masked_kspace, so no explicit device transfer is needed.
        image_buffer = torch.cat([input_image] * self.image_buffer_size, self._complex_dim)
        kspace_buffer = torch.cat([masked_kspace] * self.kspace_buffer_size, self._complex_dim)

        kspace_block_idx, image_block_idx = 0, 0
        for block_domain in self.domain_sequence:
            if block_domain == "K":
                kspace_buffer = self.kspace_correction(
                    kspace_block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace
                )
                kspace_block_idx += 1
            else:
                image_buffer = self.image_correction(
                    image_block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map
                )
                image_block_idx += 1

        # Only the first complex pair of the image buffer is the reconstruction.
        return image_buffer[..., :2]