# coding=utf-8
__author__ = "Dimitrios Karkalousos"
from abc import ABC
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.nn import L1Loss
from mridc.collections.common.losses.ssim import SSIMLoss
from mridc.collections.common.parts.fft import ifft2c
from mridc.collections.common.parts.utils import coil_combination
from mridc.collections.reconstruction.models.base import BaseMRIReconstructionModel, BaseSensitivityModel
from mridc.collections.reconstruction.models.unet_base.unet_block import NormUnet
from mridc.collections.reconstruction.parts.utils import center_crop_to_smallest
from mridc.core.classes.common import typecheck
__all__ = ["UNet"]
[docs]class UNet(BaseMRIReconstructionModel, ABC):
"""
Implementation of the UNet, as presented in O. Ronneberger, P. Fischer, and Thomas Brox.
References
----------
..
O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation. \
In International Conference on Medical image computing and computer-assisted intervention, pages 234–241. \
Springer, 2015.
"""
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# init superclass
super().__init__(cfg=cfg, trainer=trainer)
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
self.fft_type = cfg_dict.get("fft_type")
self.unet = NormUnet(
chans=cfg_dict.get("channels"),
num_pools=cfg_dict.get("pooling_layers"),
padding_size=cfg_dict.get("padding_size"),
normalize=cfg_dict.get("normalize"),
)
self.output_type = cfg_dict.get("output_type")
# Initialize the sensitivity network if use_sens_net is True
self.use_sens_net = cfg_dict.get("use_sens_net")
if self.use_sens_net:
self.sens_net = BaseSensitivityModel(
cfg_dict.get("sens_chans"),
cfg_dict.get("sens_pools"),
fft_type=self.fft_type,
mask_type=cfg_dict.get("sens_mask_type"),
normalize=cfg_dict.get("sens_normalize"),
)
# initialize weights if not using pretrained unet
# TODO if not cfg_dict.get("pretrained", False):
self.train_loss_fn = SSIMLoss() if cfg_dict.get("train_loss_fn") == "ssim" else L1Loss()
self.eval_loss_fn = SSIMLoss() if cfg_dict.get("eval_loss_fn") == "ssim" else L1Loss()
self.accumulate_estimates = False
[docs] @typecheck()
def forward(
self,
y: torch.Tensor,
sensitivity_maps: torch.Tensor,
mask: torch.Tensor,
init_pred: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass of the network.
Parameters
----------
y: Subsampled k-space data.
torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
sensitivity_maps: Coil sensitivity maps.
torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
mask: Sampling mask.
torch.Tensor, shape [1, 1, n_x, n_y, 1]
init_pred: Initial prediction.
torch.Tensor, shape [batch_size, n_x, n_y, 2]
target: Target data to compute the loss.
torch.Tensor, shape [batch_size, n_x, n_y, 2]
Returns
-------
pred: list of torch.Tensor, shape [batch_size, n_x, n_y, 2], or torch.Tensor, shape [batch_size, n_x, n_y, 2]
If self.accumulate_loss is True, returns a list of all intermediate estimates.
If False, returns the final estimate.
"""
sensitivity_maps = self.sens_net(y, mask) if self.use_sens_net else sensitivity_maps
eta = torch.view_as_complex(
coil_combination(ifft2c(y, fft_type=self.fft_type), sensitivity_maps, method=self.output_type, dim=1)
)
_, eta = center_crop_to_smallest(target, eta)
return torch.view_as_complex(self.unet(torch.view_as_real(eta.unsqueeze(1)))).squeeze(1)