# Source code for mridc.collections.reconstruction.parts.transforms
# encoding: utf-8
__author__ = "Dimitrios Karkalousos"
# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from mridc.collections.common.parts.fft import fft2c, ifft2c
from mridc.collections.common.parts.utils import complex_conj, complex_mul, to_tensor
from mridc.collections.reconstruction.data.subsample import MaskFunc
from mridc.collections.reconstruction.parts.utils import apply_mask, center_crop, complex_center_crop
__all__ = ["MRIDataTransforms"]
class MRIDataTransforms:
"""MRI preprocessing data transforms."""
def __init__(
self,
mask_func: Optional[List[MaskFunc]] = None,
shift_mask: bool = False,
mask_center_scale: Optional[float] = 0.02,
half_scan_percentage: float = 0.0,
crop_size: Optional[Tuple[int, int]] = None,
kspace_crop: bool = False,
crop_before_masking: bool = True,
kspace_zero_filling_size: Optional[Tuple] = None,
normalize_inputs: bool = False,
fft_type: str = "orthogonal",
use_seed: bool = True,
):
"""
Initialize the data transform.
Parameters
----------
        mask_func: A list of functions that mask the k-space, e.g. one per acceleration setting.
        shift_mask: Whether to fftshift the mask.
        mask_center_scale: The scale of the fully-sampled center of the mask.
        half_scan_percentage: The percentage of k-space to zero out, simulating a partial (half) scan.
        crop_size: The (height, width) size of the crop.
        kspace_crop: Whether to crop in k-space; otherwise the crop is applied in image space.
        crop_before_masking: Whether to crop before masking; otherwise the crop is applied after masking.
        kspace_zero_filling_size: The (height, width) size to zero-fill (pad) the k-space to.
        normalize_inputs: Whether to normalize the inputs by their maximum magnitude.
        fft_type: The FFT normalization type, e.g. "orthogonal", "orthogonal_norm_only", "fft_norm_only", or "backward_norm".
        use_seed: Whether to seed the mask generator with the filename, so the same file always yields the same mask.
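
        Examples
        --------
        A minimal construction sketch. The ``create_mask_for_mask_type`` helper and the
        mask settings below are illustrative assumptions, not fixed by this class.

            >>> from mridc.collections.reconstruction.data.subsample import create_mask_for_mask_type
            >>> mask_func = [create_mask_for_mask_type("equispaced", [0.08], [4])]  # assumed helper/settings
            >>> transform = MRIDataTransforms(mask_func=mask_func, fft_type="orthogonal")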
"""
self.mask_func = mask_func
self.shift_mask = shift_mask
self.mask_center_scale = mask_center_scale
self.half_scan_percentage = half_scan_percentage
self.crop_size = crop_size
self.kspace_crop = kspace_crop
self.crop_before_masking = crop_before_masking
self.kspace_zero_filling_size = kspace_zero_filling_size
self.normalize_inputs = normalize_inputs
self.fft_type = fft_type
self.use_seed = use_seed
    def __call__(
self,
kspace: np.ndarray,
sensitivity_map: np.ndarray,
mask: np.ndarray,
eta: np.ndarray,
target: np.ndarray,
attrs: Dict,
fname: str,
slice_idx: int,
) -> Tuple[
Union[Union[List[Union[torch.Tensor, Any]], torch.Tensor], Any],
Union[Optional[torch.Tensor], Any],
Union[List, Any],
Union[Optional[torch.Tensor], Any],
Union[torch.Tensor, Any],
str,
int,
Union[Union[List, torch.Tensor], Any],
]:
"""
Apply the data transform.
Parameters
----------
kspace: The kspace.
sensitivity_map: The sensitivity map.
mask: The mask.
eta: The initial estimation.
target: The target.
attrs: The attributes.
fname: The file name.
slice_idx: The slice number.
Returns
-------
        The masked k-space, the sensitivity map, the mask, the initial estimation, the target, the file name, the slice index, and the acceleration factor.
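
        Examples
        --------
        A hedged sketch with hypothetical shapes (15 coils, 320x320) and no mask function,
        so a ones mask is generated internally and the target is derived from the sensitivity maps.

            >>> import numpy as np
            >>> kspace = (np.random.randn(15, 320, 320) + 1j * np.random.randn(15, 320, 320)).astype(np.complex64)
            >>> sens = (np.random.randn(15, 320, 320) + 1j * np.random.randn(15, 320, 320)).astype(np.complex64)
            >>> transform = MRIDataTransforms(fft_type="orthogonal")
            >>> y, smap, mask, eta, target, fname, sl, acc = transform(kspace, sens, None, None, None, {}, "file.h5", 0)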
"""
kspace = to_tensor(kspace)
# This condition is necessary in case of auto estimation of sense maps.
if sensitivity_map is not None and sensitivity_map.size != 0:
sensitivity_map = to_tensor(sensitivity_map)
# Apply zero-filling on kspace
if self.kspace_zero_filling_size is not None and self.kspace_zero_filling_size not in ("", "None"):
padding_top = np.floor_divide(abs(int(self.kspace_zero_filling_size[0]) - kspace.shape[1]), 2)
padding_bottom = padding_top
padding_left = np.floor_divide(abs(int(self.kspace_zero_filling_size[1]) - kspace.shape[2]), 2)
padding_right = padding_left
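            # e.g., zero-filling a 320x320 k-space to 400x400 (hypothetical sizes) pads 40 lines on each side.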
kspace = torch.view_as_complex(kspace)
kspace = torch.nn.functional.pad(
kspace, pad=(padding_left, padding_right, padding_top, padding_bottom), mode="constant", value=0
)
kspace = torch.view_as_real(kspace)
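            # Zero-fill the sensitivity maps in k-space as well, so they match the padded k-space size.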
sensitivity_map = fft2c(sensitivity_map, self.fft_type)
sensitivity_map = torch.view_as_complex(sensitivity_map)
sensitivity_map = torch.nn.functional.pad(
sensitivity_map,
pad=(padding_left, padding_right, padding_top, padding_bottom),
mode="constant",
value=0,
)
sensitivity_map = torch.view_as_real(sensitivity_map)
sensitivity_map = ifft2c(sensitivity_map, self.fft_type)
if eta is not None and eta.size != 0:
eta = to_tensor(eta)
else:
eta = torch.tensor([])
# TODO: add RSS target option
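        # SENSE-like coil combination: target = sum over coils of conj(S_c) * ifft2c(kspace_c).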
        if isinstance(sensitivity_map, torch.Tensor) and sensitivity_map.numel() != 0:
target = torch.sum(complex_mul(ifft2c(kspace, fft_type=self.fft_type), complex_conj(sensitivity_map)), 0)
target = torch.view_as_complex(target)
elif target is not None and target.size != 0:
target = to_tensor(target)
elif "target" in attrs or "target_rss" in attrs:
target = torch.tensor(attrs["target"])
else:
raise ValueError("No target found")
target = torch.abs(target / torch.max(torch.abs(target)))
seed = None if not self.use_seed else tuple(map(ord, fname))
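        # Seeding with the filename keeps the sampled mask identical for all slices of the same volume.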
acq_start = attrs["padding_left"] if "padding_left" in attrs else 0
        acq_end = attrs["padding_right"] if "padding_right" in attrs else 0
        # Set the default crop size from the target shape, so it is always defined even when cropping is off.
        # crop_size = torch.tensor([attrs["recon_size"][0], attrs["recon_size"][1]])
        crop_size = target.shape
if self.crop_size is not None and self.crop_size not in ("", "None"):
# Check for smallest size against the target shape.
h = int(self.crop_size[0]) if int(self.crop_size[0]) <= target.shape[0] else target.shape[0]
w = int(self.crop_size[1]) if int(self.crop_size[1]) <= target.shape[1] else target.shape[1]
# Check for smallest size against the stored recon shape in metadata.
if crop_size[0] != 0:
h = h if h <= crop_size[0] else crop_size[0]
if crop_size[1] != 0:
w = w if w <= crop_size[1] else crop_size[1]
self.crop_size = (int(h), int(w))
target = center_crop(target, self.crop_size)
            if isinstance(sensitivity_map, torch.Tensor) and sensitivity_map.numel() != 0:
sensitivity_map = (
ifft2c(
complex_center_crop(fft2c(sensitivity_map, fft_type=self.fft_type), self.crop_size),
fft_type=self.fft_type,
)
if self.kspace_crop
else complex_center_crop(sensitivity_map, self.crop_size)
)
if eta is not None and eta.ndim > 2:
eta = (
ifft2c(
complex_center_crop(fft2c(eta, fft_type=self.fft_type), self.crop_size), fft_type=self.fft_type
)
if self.kspace_crop
else complex_center_crop(eta, self.crop_size)
)
        # Cropping before masking means the mask is generated for, and applied to, the cropped k-space.
if self.crop_size is not None and self.crop_size not in ("", "None") and self.crop_before_masking:
kspace = (
complex_center_crop(kspace, self.crop_size)
if self.kspace_crop
else fft2c(
complex_center_crop(ifft2c(kspace, fft_type=self.fft_type), self.crop_size), fft_type=self.fft_type
)
)
if self.mask_func is not None:
# Check for multiple masks/accelerations.
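            # Each mask function yields its own masked k-space, mask, and acceleration, so lists are returned.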
if isinstance(self.mask_func, list):
masked_kspaces = []
masks = []
accs = []
for m in self.mask_func:
_masked_kspace, _mask, _acc = apply_mask(
kspace,
m,
seed,
(acq_start, acq_end),
shift=self.shift_mask,
half_scan_percentage=self.half_scan_percentage,
center_scale=self.mask_center_scale,
)
masked_kspaces.append(_masked_kspace)
masks.append(_mask.byte())
accs.append(_acc)
masked_kspace = masked_kspaces
mask = masks
acc = accs
else:
masked_kspace, mask, acc = apply_mask(
kspace,
self.mask_func[0], # type: ignore
seed,
(acq_start, acq_end),
shift=self.shift_mask,
half_scan_percentage=self.half_scan_percentage,
center_scale=self.mask_center_scale,
)
mask = mask.byte()
else:
masked_kspace = kspace
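            # Estimate the acceleration factor as the number of k-space locations over the number sampled.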
acc = torch.tensor([np.around(mask.size / mask.sum())]) if mask is not None else torch.tensor([1])
if mask is not None:
mask = torch.from_numpy(mask)
if mask.shape[0] == masked_kspace.shape[2]: # type: ignore
mask = mask.permute(1, 0)
elif mask.shape[0] != masked_kspace.shape[1]: # type: ignore
mask = torch.ones(
[masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32 # type: ignore
)
else:
mask = torch.ones(
[masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32 # type: ignore
)
            if mask.ndim == 1:
                mask = mask.unsqueeze(0)
            if mask.shape[-2] == 1:  # 1D mask
                mask = mask.unsqueeze(0).unsqueeze(-1)
else: # 2D mask
# Crop loaded mask.
if self.crop_size is not None and self.crop_size not in ("", "None"):
mask = center_crop(mask, self.crop_size)
mask = mask.unsqueeze(0).unsqueeze(-1)
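                # The final mask shape (1, H, W, 1) broadcasts over the coil and complex dimensions.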
if self.shift_mask:
mask = torch.fft.fftshift(mask, dim=[-3, -2])
masked_kspace = masked_kspace * mask
mask = mask.byte()
# Cropping after masking.
if self.crop_size is not None and self.crop_size not in ("", "None") and not self.crop_before_masking:
masked_kspace = (
complex_center_crop(masked_kspace, self.crop_size)
if self.kspace_crop
else fft2c(
complex_center_crop(ifft2c(masked_kspace, fft_type=self.fft_type), self.crop_size),
fft_type=self.fft_type,
)
)
mask = center_crop(mask.squeeze(-1), self.crop_size).unsqueeze(-1)
# Normalize by the max value.
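        # Where applicable, normalization goes to image space, divides by the maximum magnitude, and transforms back.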
if self.normalize_inputs:
if isinstance(self.mask_func, list):
masked_kspaces = []
for y in masked_kspace:
if self.fft_type in ("orthogonal", "orthogonal_norm_only"):
imspace = ifft2c(y, fft_type=self.fft_type)
imspace = imspace / torch.max(torch.abs(imspace))
masked_kspaces.append(fft2c(imspace, fft_type=self.fft_type))
elif self.fft_type == "fft_norm_only":
imspace = ifft2c(y, fft_type=self.fft_type)
masked_kspaces.append(fft2c(imspace, fft_type=self.fft_type))
elif self.fft_type == "backward_norm":
imspace = ifft2c(y, fft_type=self.fft_type, fft_normalization="backward")
masked_kspaces.append(fft2c(imspace, fft_type=self.fft_type, fft_normalization="backward"))
else:
imspace = torch.fft.ifftn(torch.view_as_complex(y), dim=[-2, -1], norm=None)
imspace = imspace / torch.max(torch.abs(imspace))
masked_kspaces.append(torch.view_as_real(torch.fft.fftn(imspace, dim=[-2, -1], norm=None)))
masked_kspace = masked_kspaces
else:
if self.fft_type in ("orthogonal", "orthogonal_norm_only"):
imspace = ifft2c(masked_kspace, fft_type=self.fft_type)
imspace = imspace / torch.max(torch.abs(imspace))
masked_kspace = fft2c(imspace, fft_type=self.fft_type)
elif self.fft_type == "fft_norm_only":
masked_kspace = fft2c(ifft2c(masked_kspace, fft_type=self.fft_type), fft_type=self.fft_type)
elif self.fft_type == "backward_norm":
masked_kspace = fft2c(
ifft2c(masked_kspace, fft_type=self.fft_type, fft_normalization="backward"),
fft_type=self.fft_type,
fft_normalization="backward",
)
else:
imspace = torch.fft.ifftn(torch.view_as_complex(masked_kspace), dim=[-2, -1], norm=None)
imspace = imspace / torch.max(torch.abs(imspace))
masked_kspace = torch.view_as_real(torch.fft.fftn(imspace, dim=[-2, -1], norm=None))
            if isinstance(sensitivity_map, torch.Tensor) and sensitivity_map.numel() != 0:
                sensitivity_map = sensitivity_map / torch.max(torch.abs(sensitivity_map))
            if eta.numel() != 0 and eta.ndim > 2:
                eta = eta / torch.max(torch.abs(eta))
target = target / torch.max(torch.abs(target))
return masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc