# coding=utf-8
__author__ = "Dimitrios Karkalousos"
from abc import ABC
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.nn import L1Loss
from mridc.collections.common.losses.ssim import SSIMLoss
from mridc.collections.common.parts.fft import fft2c, ifft2c
from mridc.collections.common.parts.utils import complex_conj, complex_mul
from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
from mridc.collections.reconstruction.models.conv.conv2d import Conv2d
from mridc.collections.reconstruction.models.crossdomain.multicoil import MultiCoil
from mridc.collections.reconstruction.models.didn.didn import DIDN
from mridc.collections.reconstruction.models.mwcnn.mwcnn import MWCNN
from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
from mridc.core.classes.common import typecheck
__all__ = ["KIKINet"]
class KIKINet(BaseMRIReconstructionModel, ABC):
"""
Based on KIKINet implementation [1]. Modified to work with multi-coil k-space data, as presented in Eo, Taejoon, \
et al.
References
----------
..
Eo, Taejoon, et al. “KIKI-Net: Cross-Domain Convolutional Neural Networks for Reconstructing Undersampled \
Magnetic Resonance Images.” Magnetic Resonance in Medicine, vol. 80, no. 5, Nov. 2018, pp. 2188–201. PubMed, \
https://doi.org/10.1002/mrm.27201.
"""
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# init superclass
super().__init__(cfg=cfg, trainer=trainer)
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
self.num_iter = cfg_dict.get("num_iter")
self.no_dc = cfg_dict.get("no_dc")
kspace_model_architecture = cfg_dict.get("kspace_model_architecture")
if kspace_model_architecture == "CONV":
kspace_model = Conv2d(
in_channels=2,
out_channels=2,
hidden_channels=cfg_dict.get("kspace_conv_hidden_channels"),
n_convs=cfg_dict.get("kspace_conv_n_convs"),
batchnorm=cfg_dict.get("kspace_conv_batchnorm"),
)
elif kspace_model_architecture == "DIDN":
kspace_model = DIDN(
in_channels=2,
out_channels=2,
hidden_channels=cfg_dict.get("kspace_didn_hidden_channels"),
num_dubs=cfg_dict.get("kspace_didn_num_dubs"),
num_convs_recon=cfg_dict.get("kspace_didn_num_convs_recon"),
)
elif kspace_model_architecture in ["UNET", "NORMUNET"]:
kspace_model = NormUnet(
cfg_dict.get("kspace_unet_num_filters"),
cfg_dict.get("kspace_unet_num_pool_layers"),
in_chans=2,
out_chans=2,
drop_prob=cfg_dict.get("kspace_unet_dropout_probability"),
padding_size=cfg_dict.get("kspace_unet_padding_size"),
normalize=cfg_dict.get("kspace_unet_normalize"),
)
else:
raise NotImplementedError(
f"KIKINet is currently implemented for kspace_model_architecture == 'CONV' or 'DIDN' or 'UNet'."
f"Got kspace_model_architecture == {kspace_model_architecture}."
)
image_model_architecture = cfg_dict.get("imspace_model_architecture")
if image_model_architecture == "MWCNN":
image_model = MWCNN(
input_channels=2,
first_conv_hidden_channels=cfg_dict.get("image_mwcnn_hidden_channels"),
num_scales=cfg_dict.get("image_mwcnn_num_scales"),
bias=cfg_dict.get("image_mwcnn_bias"),
batchnorm=cfg_dict.get("image_mwcnn_batchnorm"),
)
elif image_model_architecture in ["UNET", "NORMUNET"]:
image_model = NormUnet(
cfg_dict.get("imspace_unet_num_filters"),
cfg_dict.get("imspace_unet_num_pool_layers"),
in_chans=2,
out_chans=2,
drop_prob=cfg_dict.get("imspace_unet_dropout_probability"),
padding_size=cfg_dict.get("imspace_unet_padding_size"),
normalize=cfg_dict.get("imspace_unet_normalize"),
)
else:
raise NotImplementedError(
f"KIKINet is currently implemented only with image_model_architecture == 'MWCNN' or 'UNet'."
f"Got {image_model_architecture}."
)
self.fft_type = cfg_dict.get("fft_type")
self._coil_dim = 1
self.image_model_list = torch.nn.ModuleList([image_model] * self.num_iter)
self.kspace_model_list = torch.nn.ModuleList([MultiCoil(kspace_model, self._coil_dim)] * self.num_iter)
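        # Note: `[module] * self.num_iter` registers the *same* module instance num_iter times, so
        # the k-space and image sub-networks share weights across all iterations. If untied weights
        # per iteration were intended, a deep copy would be needed instead (a sketch, not the
        # original behavior):
        #
        #   import copy
        #   self.image_model_list = torch.nn.ModuleList(
        #       [copy.deepcopy(image_model) for _ in range(self.num_iter)]
        #   )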
# Initialize the sensitivity network if use_sens_net is True
self.use_sens_net = cfg_dict.get("use_sens_net")
if self.use_sens_net:
self.sens_net = BaseSensitivityModel(
cfg_dict.get("sens_chans"),
cfg_dict.get("sens_pools"),
fft_type=self.fft_type,
mask_type=cfg_dict.get("sens_mask_type"),
normalize=cfg_dict.get("sens_normalize"),
)
self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()
self.output_type = cfg_dict.get("output_type")
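        # Learnable scalar weight on the soft data-consistency term used in forward(); initialized
        # to 1 and trained jointly with the rest of the network.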
self.dc_weight = torch.nn.Parameter(torch.ones(1))
self.accumulate_estimates = False
    @typecheck()
def forward(
self,
y: torch.Tensor,
sensitivity_maps: torch.Tensor,
mask: torch.Tensor,
init_pred: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass of the network.
Parameters
----------
y: Subsampled k-space data.
torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
sensitivity_maps: Coil sensitivity maps.
torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
mask: Sampling mask.
torch.Tensor, shape [1, 1, n_x, n_y, 1]
init_pred: Initial prediction.
torch.Tensor, shape [batch_size, n_x, n_y, 2]
target: Target data to compute the loss.
torch.Tensor, shape [batch_size, n_x, n_y, 2]
Returns
-------
pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
            If self.accumulate_estimates is True, returns a list of all intermediate estimates.
If False, returns the final estimate.
"""
sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
kspace = y.clone()
zero = torch.zeros(1, 1, 1, 1, 1).to(kspace)
for idx in range(self.num_iter):
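            # Soft data consistency: where the mask sampled k-space, penalize deviation of the
            # current k-space estimate from the measured data, scaled by the learnable dc_weight:
            #   soft_dc = dc_weight * M * (kspace - y)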
soft_dc = torch.where(mask.bool(), kspace - y, zero) * self.dc_weight
kspace = self.kspace_model_list[idx](kspace)
            if kspace.shape[-1] != 2:
                kspace = kspace.permute(0, 1, 3, 4, 2).to(target)
                # Rebuild the complex pair explicitly: the permute above returns a non-contiguous
                # view, and recombining through view_as_real yields the contiguous [..., 2] layout
                # that the complex helpers below expect.
                kspace = torch.view_as_real(kspace[..., 0] + 1j * kspace[..., 1])
            # Coil-combine: inverse FFT back to image space, weighted by the conjugate sensitivity
            # maps, then refine with the image-space network.
            image = complex_mul(ifft2c(kspace, fft_type=self.fft_type), complex_conj(sensitivity_maps)).sum(1)
            image = self.image_model_list[idx](image.unsqueeze(1)).squeeze(1)
            if not self.no_dc:
                # Apply soft data consistency in k-space: re-project the image estimate through the
                # sensitivities and the FFT, combine it with the current k-space estimate and the DC
                # term, and coil-combine back to image space.
                image = fft2c(complex_mul(image.unsqueeze(1), sensitivity_maps), fft_type=self.fft_type).type(
                    image.type()
                )
                image = kspace - soft_dc - image
                image = complex_mul(ifft2c(image, fft_type=self.fft_type), complex_conj(sensitivity_maps)).sum(1)
            if idx < self.num_iter - 1:
                # Re-project to k-space to serve as the input of the next KIKI iteration.
                kspace = fft2c(complex_mul(image.unsqueeze(1), sensitivity_maps), fft_type=self.fft_type).type(
                    image.type()
                )
image = torch.view_as_complex(image)
_, image = center_crop_to_smallest(target, image)
return image
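
# A minimal shape sketch for forward() (illustrative only; the tensor sizes and random inputs are
# assumptions consistent with the docstring, not part of the model):
#
#   model = KIKINet(cfg)  # cfg as sketched in the class body above
#   y = torch.randn(1, 8, 320, 320, 2)                        # subsampled multi-coil k-space
#   maps = torch.randn(1, 8, 320, 320, 2)                     # coil sensitivity maps
#   mask = torch.randint(0, 2, (1, 1, 320, 320, 1)).float()   # sampling mask
#   target = torch.randn(1, 320, 320, 2)
#   pred = model(y, maps, mask, init_pred=y.sum(1), target=target)
#   # pred: complex-valued tensor of shape [1, 320, 320], center-cropped to the target size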