Source code for mridc.collections.reconstruction.models.variablesplittingnet.vsnet_block

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

from typing import Any, List, Union

import torch

from mridc.collections.common.parts.fft import fft2c, ifft2c
from mridc.collections.common.parts.utils import complex_conj, complex_mul


class DataConsistencyLayer(torch.nn.Module):
    """
    Data consistency layer for the VSNet.

    Blends a predicted k-space with a reference k-space according to a mask and scales the blend by a
    single learnable weight.
    """

    def __init__(self):
        """Create the layer with one learnable scalar weight, initialized to 1."""
        super().__init__()
        self.dc_weight = torch.nn.Parameter(torch.ones(1))

    def forward(self, pred_kspace, ref_kspace, mask):
        """
        Combine the predicted and the reference k-space.

        Parameters
        ----------
        pred_kspace:
            Predicted k-space; weighted by ``1 - mask``.
        ref_kspace:
            Reference k-space; weighted by ``mask``.
        mask:
            Weighting mask (presumably a binary sampling mask -- confirm against callers).

        Returns
        -------
        ``((1 - mask) * pred_kspace + mask * ref_kspace) * dc_weight``.
        """
        from_prediction = (1 - mask) * pred_kspace
        from_reference = mask * ref_kspace
        return (from_prediction + from_reference) * self.dc_weight
class WeightedAverageTerm(torch.nn.Module):
    """
    Weighted average term for the VSNet.

    Mixes two inputs with a single learnable weight ``param``; the weights of the two inputs always
    sum to one.
    """

    def __init__(self):
        """Create the term with the mixing weight initialized to 1."""
        super().__init__()
        self.param = torch.nn.Parameter(torch.ones(1))

    def forward(self, x, Sx):
        """
        Mix the two inputs.

        Parameters
        ----------
        x:
            First input; weighted by ``param``.
        Sx:
            Second input; weighted by ``1 - param``.

        Returns
        -------
        ``param * x + (1 - param) * Sx``.
        """
        weight = self.param
        weighted_x = weight * x
        weighted_sx = (1 - weight) * Sx
        return weighted_x + weighted_sx
class VSNetBlock(torch.nn.Module):
    """
    Cascaded model block for the Variable-Splitting Network, inspired by [1]_.

    Every cascade coil-combines the current k-space estimate, denoises it, expands it back through the
    sensitivity maps, applies a data consistency module, and merges the results with a weighted
    average module.

    References
    ----------
    .. [1] Duan, J. et al. (2019) ‘Vs-net: Variable splitting network for accelerated parallel MRI
        reconstruction’, Lecture Notes in Computer Science (including subseries Lecture Notes in
        Artificial Intelligence and Lecture Notes in Bioinformatics), 11767 LNCS, pp. 713–722.
        doi: 10.1007/978-3-030-32251-9_78.
    """

    def __init__(
        self,
        denoiser_block: torch.nn.ModuleList,
        data_consistency_block: torch.nn.ModuleList,
        weighted_average_block: torch.nn.ModuleList,
        num_cascades: int = 8,
        fft_type: str = "orthogonal",
    ):
        """
        Store the per-cascade modules and the transform settings.

        Parameters
        ----------
        denoiser_block:
            Denoising modules, indexed by cascade.
        data_consistency_block:
            Data consistency modules, indexed by cascade.
        weighted_average_block:
            Weighted average modules, indexed by cascade.
        num_cascades:
            Number of cascades to run; each module list is indexed from 0 to ``num_cascades - 1``.
        fft_type:
            FFT type forwarded to ``fft2c`` / ``ifft2c``.
        """
        super().__init__()

        # One module per cascade in each of the three lists.
        self.denoiser_block = denoiser_block
        self.data_consistency_block = data_consistency_block
        self.weighted_average_block = weighted_average_block

        self.num_cascades = num_cascades
        self.fft_type = fft_type

    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Multiply the input with the sensitivity maps and transform the product with ``fft2c``.

        Parameters
        ----------
        x:
            Input data.
        sens_maps:
            Coil sensitivity maps.

        Returns
        -------
        ``fft2c`` of the complex product of ``x`` and ``sens_maps``.
        """
        coil_images = complex_mul(x, sens_maps)
        return fft2c(coil_images, fft_type=self.fft_type)

    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """
        Transform the input with ``ifft2c`` and combine it with the conjugate sensitivity maps.

        Parameters
        ----------
        x:
            Input data.
        sens_maps:
            Coil sensitivity maps.

        Returns
        -------
        Complex product of the transformed input and the conjugated maps, summed over dimension 1
        (presumably the coil dimension -- confirm against callers).
        """
        coil_images = ifft2c(x, fft_type=self.fft_type)
        conj_maps = complex_conj(sens_maps)
        return complex_mul(coil_images, conj_maps).sum(1)

    def forward(
        self,
        kspace: torch.Tensor,
        sens_maps: torch.Tensor,
        mask: torch.Tensor,
    ) -> List[Union[torch.Tensor, Any]]:
        """
        Run all cascades on the given k-space.

        Parameters
        ----------
        kspace:
            Reference k-space data. It is overwritten by each cascade's output, so later cascades use
            the updated estimate as their reference.
        sens_maps:
            Coil sensitivity maps.
        mask:
            Mask passed to the data consistency modules.

        Returns
        -------
        Output of the weighted average module of the last cascade, or ``kspace`` unchanged when
        ``num_cascades`` is 0.
        NOTE(review): a single tensor is returned although the annotation says ``List``, and the
        original docstring called it the "Reconstructed image" -- confirm the intended contract.
        """
        for cascade in range(self.num_cascades):
            # Coil-combine the current estimate.
            combined = self.sens_reduce(kspace, sens_maps)

            # The denoiser sees the last dimension moved to position 1; the move is undone afterwards.
            denoiser = self.denoiser_block[cascade]
            denoised = denoiser(combined.permute(0, 3, 1, 2))
            denoised = denoised.permute(0, 2, 3, 1)

            # Back through the sensitivity maps and the forward transform.
            pred = self.sens_expand(denoised, sens_maps)

            # Data consistency against the current k-space, followed by another coil combination.
            data_consistency = self.data_consistency_block[cascade]
            consistent = data_consistency(pred, kspace, mask)
            consistent = self.sens_reduce(consistent, sens_maps)

            # Merge the summed estimate with the data-consistent term for the next cascade.
            weighted_average = self.weighted_average_block[cascade]
            kspace = weighted_average(kspace + pred, consistent)

        return kspace