# encoding: utf-8
# Source code for mridc.collections.common.parts.fft
__author__ = "Dimitrios Karkalousos"

# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI

from typing import List, Optional, Union

import numpy as np
import torch

__all__ = ["fft2c", "ifft2c"]


def fft2c(
    data: torch.Tensor,
    fft_type: str = "orthogonal",
    fft_normalization: str = "ortho",
    fft_dim: Union[int, List[int], None] = None,
) -> torch.Tensor:
    """
    Apply centered 2-dimensional Fast Fourier Transform.

    Parameters
    ----------
    data: Complex-valued input data, stored as a real tensor whose last dimension of size 2 holds the real and
        imaginary parts, containing at least 3 dimensions: dimensions -3 & -2 are the spatial dimensions. All other
        dimensions are assumed to be batch dimensions.
    fft_type: Specify the FFT type. If "orthogonal", the data is ifftshift-ed before and fftshift-ed after the
        transform, i.e. the zero-frequency component is assumed to lie at the center.
    fft_normalization: Normalization mode passed to torch.fft. Defaults to "ortho"; other valid values are
        "forward", "backward", and None.
    fft_dim: Dimensions along which to apply the FFT. Defaults to [-2, -1], the spatial dimensions of the complex
        view.

    Returns
    -------
    The FFT of the input.
    """
    if fft_dim is None:
        fft_dim = [-2, -1]

    if fft_type == "orthogonal":
        # The shifts operate on the real view, where the last dimension holds the (real, imag) pair, so the
        # spatial dimensions are [-3, -2].
        data = ifftshift(data, dim=[-3, -2])

    data = torch.view_as_real(torch.fft.fft2(torch.view_as_complex(data), dim=fft_dim, norm=fft_normalization))

    if fft_type == "orthogonal":
        data = fftshift(data, dim=[-3, -2])

    return data
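
# A minimal usage sketch (not part of the original module; the shape below is illustrative): transform a
# coil-stacked image, stored as real/imaginary pairs, to k-space.
#
#     image = torch.randn(4, 32, 32, 2)  # (coils, height, width, real/imag)
#     kspace = fft2c(image)              # centered, orthonormal 2D FFT
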
def ifft2c(
    data: torch.Tensor,
    fft_type: str = "orthogonal",
    fft_normalization: str = "ortho",
    fft_dim: Union[int, List[int], None] = None,
) -> torch.Tensor:
    """
    Apply centered 2-dimensional Inverse Fast Fourier Transform.

    Parameters
    ----------
    data: Complex-valued input data, stored as a real tensor whose last dimension of size 2 holds the real and
        imaginary parts, containing at least 3 dimensions: dimensions -3 & -2 are the spatial dimensions. All other
        dimensions are assumed to be batch dimensions.
    fft_type: Specify the FFT type. If "orthogonal", the data is ifftshift-ed before and fftshift-ed after the
        transform, i.e. the zero-frequency component is assumed to lie at the center.
    fft_normalization: Normalization mode passed to torch.fft. Defaults to "ortho"; other valid values are
        "forward", "backward", and None.
    fft_dim: Dimensions along which to apply the IFFT. Defaults to [-2, -1], the spatial dimensions of the complex
        view.

    Returns
    -------
    The IFFT of the input.
    """
    if fft_dim is None:
        fft_dim = [-2, -1]

    if fft_type == "orthogonal":
        # As in fft2c, the shifts operate on the real view, so the spatial dimensions are [-3, -2].
        data = ifftshift(data, dim=[-3, -2])

    data = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(data), dim=fft_dim, norm=fft_normalization))

    if fft_type == "orthogonal":
        data = fftshift(data, dim=[-3, -2])

    return data
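
# A minimal sketch of the inverse direction (not part of the original module; shapes illustrative). With the
# matching defaults, ifft2c undoes fft2c up to floating-point error, because norm="ortho" makes both transforms
# unitary and fftshift/ifftshift are exact inverses of each other.
#
#     kspace = torch.randn(4, 32, 32, 2)
#     image = ifft2c(kspace)             # centered, orthonormal 2D inverse FFT
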
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Similar to roll but for only one dim.

    Parameters
    ----------
    x: A PyTorch tensor.
    shift: Amount to roll.
    dim: Which dimension to roll.

    Returns
    -------
    Rolled version of x.
    """
    shift %= x.size(dim)
    if shift == 0:
        return x

    left = x.narrow(dim, 0, x.size(dim) - shift)
    right = x.narrow(dim, x.size(dim) - shift, shift)

    return torch.cat((right, left), dim=dim)


def roll(x: torch.Tensor, shift: List[int], dim: List[int]) -> torch.Tensor:
    """
    Similar to np.roll but applies to PyTorch Tensors.

    Parameters
    ----------
    x: A PyTorch tensor.
    shift: Amount to roll.
    dim: Which dimension to roll.

    Returns
    -------
    Rolled version of x.
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")

    for (s, d) in zip(shift, dim):
        x = roll_one_dim(x, s, d)

    return x


def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors.

    Parameters
    ----------
    x: A PyTorch tensor.
    dim: Which dimension to fftshift.

    Returns
    -------
    fftshifted version of x.
    """
    if dim is None:
        # this weird code is necessary for torch.jit.script typing
        dim = [0] * (x.dim())
        for i in range(1, x.dim()):
            dim[i] = i

    # Also necessary for torch.jit.script
    shift = [0] * len(dim)
    for i, dim_num in enumerate(dim):
        shift[i] = np.floor_divide(x.shape[dim_num], 2)

    return roll(x, shift, dim)


def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors.

    Parameters
    ----------
    x: A PyTorch tensor.
    dim: Which dimension to ifftshift.

    Returns
    -------
    ifftshifted version of x.
    """
    if dim is None:
        # this weird code is necessary for torch.jit.script typing
        dim = [0] * (x.dim())
        for i in range(1, x.dim()):
            dim[i] = i

    # Also necessary for torch.jit.script
    shift = [0] * len(dim)
    for i, dim_num in enumerate(dim):
        shift[i] = np.floor_divide(x.shape[dim_num] + 1, 2)

    return roll(x, shift, dim)
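

# Minimal self-check (not part of the original module; shapes and tolerance are illustrative). It verifies that
# fft2c/ifft2c round-trip their input and that fftshift/ifftshift agree with their NumPy counterparts.
if __name__ == "__main__":
    kspace = torch.randn(4, 32, 32, 2)  # complex data stored as real pairs
    assert torch.allclose(fft2c(ifft2c(kspace)), kspace, atol=1e-5)

    x = torch.randn(5, 7)
    assert torch.equal(fftshift(x), torch.from_numpy(np.fft.fftshift(x.numpy()).copy()))
    assert torch.equal(ifftshift(x), torch.from_numpy(np.fft.ifftshift(x.numpy()).copy()))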