Source code for mridc.collections.reconstruction.models.rim.utils

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

import torch

from mridc.collections.common.parts.fft import fft2c, ifft2c


def log_likelihood_gradient(
    eta: torch.Tensor,
    masked_kspace: torch.Tensor,
    sense: torch.Tensor,
    mask: torch.Tensor,
    sigma: float,
    fft_type: str = "orthogonal",
) -> torch.Tensor:
    """
    Computes the gradient of the log-likelihood function.

    Parameters
    ----------
    eta: Initial guess for the reconstruction.
    masked_kspace: Subsampled k-space data.
    sense: Coil sensitivity maps.
    mask: Sampling mask.
    sigma: Noise level.
    fft_type: Type of FFT to use.

    Returns
    -------
    Gradient of the log-likelihood function.
    """
    # Split the estimate and the sensitivity maps into real and imaginary parts.
    eta_real, eta_imag = map(lambda x: torch.unsqueeze(x, 0), eta.chunk(2, -1))
    sense_real, sense_imag = sense.chunk(2, -1)

    # Complex multiplication of the estimate with the sensitivity maps (coil expansion).
    re_se = eta_real * sense_real - eta_imag * sense_imag
    im_se = eta_real * sense_imag + eta_imag * sense_real

    # Data-consistency residual: go to k-space, subtract the measured k-space,
    # apply the sampling mask, and transform back to image space.
    pred = ifft2c(
        mask * (fft2c(torch.cat((re_se, im_se), -1), fft_type=fft_type) - masked_kspace),
        fft_type=fft_type,
    )
    pred_real, pred_imag = pred.chunk(2, -1)

    # Complex multiplication with the conjugate sensitivity maps, coil combination
    # over the coil dimension, and scaling by the noise variance.
    re_out = torch.sum(pred_real * sense_real + pred_imag * sense_imag, 1) / (sigma**2.0)
    im_out = torch.sum(pred_imag * sense_real - pred_real * sense_imag, 1) / (sigma**2.0)

    eta_real = eta_real.squeeze(0)
    eta_imag = eta_imag.squeeze(0)

    # Stack the current estimate and the gradient along the channel dimension.
    return torch.cat((eta_real, eta_imag, re_out, im_out), 0).unsqueeze(0).squeeze(-1)
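
# A minimal usage sketch, not part of the original module. The tensor shapes below are
# assumptions: one sample, 8 coils, a 320 x 320 matrix, with real/imaginary parts stored
# in the last dimension, and a hypothetical random 1D Cartesian sampling mask. In effect,
# the function evaluates the data-consistency gradient S^H F^H M (F S eta - y) / sigma^2
# and stacks it with the current estimate eta along a channel dimension, yielding a
# tensor of shape [1, 4, height, width].
if __name__ == "__main__":
    coils, height, width = 8, 320, 320

    eta = torch.randn(1, height, width, 2)  # current image estimate (real/imag)
    sense = torch.randn(1, coils, height, width, 2)  # coil sensitivity maps
    masked_kspace = torch.randn(1, coils, height, width, 2)  # subsampled k-space
    mask = (torch.rand(1, 1, 1, width, 1) > 0.5).float()  # assumed 1D Cartesian mask

    grad = log_likelihood_gradient(eta, masked_kspace, sense, mask, sigma=1.0)
    print(grad.shape)  # torch.Size([1, 4, 320, 320])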