Source code for mridc.collections.reconstruction.models.sigmanet.dc_layers

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from:
# https://github.com/khammernik/sigmanet/blob/master/reconstruction/common/mytorch/models/datalayer.py

import torch

from mridc.collections.common.parts.fft import fft2c, ifft2c
from mridc.collections.common.parts.utils import complex_abs, complex_conj, complex_mul


class DataIDLayer(torch.nn.Module):
    """Placeholder for the data layer."""

    def __init__(self, *args, **kwargs):
        super().__init__()
class DataGDLayer(torch.nn.Module):
    """DataLayer computing the gradient on the L2 data term."""

    def __init__(self, lambda_init, learnable=True, fft_type="orthogonal"):
        """
        Parameters
        ----------
        lambda_init: Init value of the data term weight lambda.
        learnable: If True, the data term weight lambda is learnable.
        fft_type: Type of FFT to use.
        """
        super().__init__()
        self.lambda_init = lambda_init
        self.data_weight = torch.nn.Parameter(torch.Tensor(1))
        self.data_weight.data = torch.tensor(lambda_init, dtype=self.data_weight.dtype)
        self.data_weight.requires_grad = learnable
        self.fft_type = fft_type
    def forward(self, x, y, smaps, mask):
        """
        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        smaps: Coil sensitivity maps.
        mask: Sampling mask.

        Returns
        -------
        Image after one gradient step on the data term.
        """
        # A(x) - y: expand to coils, apply sensitivities, FFT, and mask, then subtract the data.
        A_x_y = (
            torch.sum(
                fft2c(complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps), fft_type=self.fft_type) * mask,
                -4,
                keepdim=True,
            )
            - y
        )
        # A^H(A(x) - y): mask, inverse FFT, and coil-combine with the conjugate sensitivities.
        gradD_x = torch.sum(complex_mul(ifft2c(A_x_y * mask, fft_type=self.fft_type), complex_conj(smaps)), dim=-5)
        return x - self.data_weight * gradD_x
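# Usage sketch (added for illustration; not part of the original module). The
# tensor shapes below are assumptions: the image x is (batch, 1, height, width, 2)
# and y/smaps are (batch, coils, 1, height, width, 2), with the trailing dimension
# holding the real and imaginary parts.
#
# >>> layer = DataGDLayer(lambda_init=0.1)
# >>> x = torch.randn(1, 1, 32, 32, 2)                         # current image estimate
# >>> smaps = torch.randn(1, 4, 1, 32, 32, 2)                  # coil sensitivity maps
# >>> y = torch.randn(1, 4, 1, 32, 32, 2)                      # undersampled k-space
# >>> mask = torch.randint(0, 2, (1, 1, 1, 1, 32, 1)).float()  # sampling mask
# >>> x_new = layer(x, y, smaps, mask)                         # same shape as x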
class DataProxCGLayer(torch.nn.Module):
    """Solving the prox wrt. the data term using Conjugate Gradient as proposed by Aggarwal et al."""

    def __init__(self, lambda_init, tol=1e-6, iter=10, learnable=True, fft_type="orthogonal"):
        """
        Parameters
        ----------
        lambda_init: Init value of the data term weight lambda.
        tol: Tolerance of the conjugate gradient stopping criterion.
        iter: Maximum number of conjugate gradient iterations.
        learnable: If True, the data term weight lambda is learnable.
        fft_type: Type of FFT to use.
        """
        super().__init__()
        self.lambdaa = torch.nn.Parameter(torch.Tensor(1))
        self.lambdaa.data = torch.tensor(lambda_init)
        self.lambdaa_init = lambda_init
        self.lambdaa.requires_grad = learnable
        self.tol = tol
        self.iter = iter
        self.op = ConjugateGradient
        self.fft_type = fft_type
    def forward(self, x, f, smaps, mask):
        """
        Parameters
        ----------
        x: Input image.
        f: Subsampled k-space data.
        smaps: Coil sensitivity maps.
        mask: Sampling mask.

        Returns
        -------
        Solution of the prox of the data term, computed with conjugate gradient.
        """
        return self.op.apply(x, self.lambdaa, f, smaps, mask, self.tol, self.iter, self.fft_type)
    def set_learnable(self, flag):
        """
        Set the learnable flag of the data term weight.

        Parameters
        ----------
        flag: If True, the data term weight lambda is learnable.
        """
        self.lambdaa.requires_grad = flag
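# Usage sketch (added for illustration; not part of the original module). Shapes
# follow the same assumed convention as the DataGDLayer example above:
#
# >>> layer = DataProxCGLayer(lambda_init=0.1, tol=1e-6, iter=5)
# >>> z = torch.randn(1, 1, 32, 32, 2)
# >>> smaps = torch.randn(1, 4, 1, 32, 32, 2)
# >>> f = torch.randn(1, 4, 1, 32, 32, 2)
# >>> mask = torch.randint(0, 2, (1, 1, 1, 1, 32, 1)).float()
# >>> x_prox = layer(z, f, smaps, mask)                        # same shape as z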
class ConjugateGradient(torch.autograd.Function):
    """Conjugate Gradient solver for the prox of the data term."""
    @staticmethod
    def complexDot(data1, data2):
        """Complex dot product of two tensors."""
        nBatch = data1.shape[0]
        mult = complex_mul(data1, complex_conj(data2))
        re, im = torch.unbind(mult, dim=-1)
        return torch.stack([torch.sum(re.view(nBatch, -1), dim=-1), torch.sum(im.view(nBatch, -1), dim=-1)], -1)
    @staticmethod
    def solve(x0, M, tol, max_iter):
        """Solve the linear system Mx = b using conjugate gradient."""
        nBatch = x0.shape[0]
        x = torch.zeros_like(x0)
        r = x0.clone()
        p = x0.clone()
        x0x0 = (x0.pow(2)).view(nBatch, -1).sum(-1)
        rr = torch.stack([(r.pow(2)).view(nBatch, -1).sum(-1), torch.zeros(nBatch).to(x0.device)], dim=-1)
        it = 0
        while torch.min(rr[..., 0] / x0x0) > tol and it < max_iter:
            it += 1
            q = M(p)
            # Step size alpha = (r^H r) / (p^H q), evaluated as a complex division.
            data1 = rr
            data2 = ConjugateGradient.complexDot(p, q)
            re1, im1 = torch.unbind(data1, -1)
            re2, im2 = torch.unbind(data2, -1)
            alpha = torch.stack([re1 * re2 + im1 * im2, im1 * re2 - re1 * im2], -1) / complex_abs(data2) ** 2
            x += complex_mul(alpha.reshape(nBatch, 1, 1, 1, -1), p.clone())
            r -= complex_mul(alpha.reshape(nBatch, 1, 1, 1, -1), q.clone())
            rr_new = torch.stack([(r.pow(2)).view(nBatch, -1).sum(-1), torch.zeros(nBatch).to(x0.device)], dim=-1)
            # beta = (r_new^H r_new) / (r^H r): update the conjugate search direction.
            beta = torch.stack([rr_new[..., 0] / rr[..., 0], torch.zeros(nBatch).to(x0.device)], dim=-1)
            p = r.clone() + complex_mul(beta.reshape(nBatch, 1, 1, 1, -1), p)
            rr = rr_new.clone()
        return x
    @staticmethod
    def forward(ctx, z, lambdaa, y, smaps, mask, tol, max_iter, fft_type):
        """
        Forward pass of the conjugate gradient solver.

        Parameters
        ----------
        ctx: Context object.
        z: Input image.
        lambdaa: Regularization parameter.
        y: Subsampled k-space data.
        smaps: Coil sensitivity maps.
        mask: Sampling mask.
        tol: Tolerance for the stopping criterion.
        max_iter: Maximum number of iterations.
        fft_type: FFT type.

        Returns
        -------
        z: Output image.
        """
        ctx.tol = tol
        ctx.max_iter = max_iter
        ctx.fft_type = fft_type

        def A(x):
            x = fft2c(complex_mul(x.expand_as(smaps), smaps), fft_type=fft_type) * mask
            return torch.sum(x, dim=-4, keepdim=True)

        def AT(x):
            return torch.sum(complex_mul(ifft2c(x * mask, fft_type=fft_type), complex_conj(smaps)), dim=-5)

        def M(p):
            return lambdaa * AT(A(p)) + p

        x0 = lambdaa * AT(y) + z
        ctx.save_for_backward(AT(y), x0, smaps, mask, lambdaa)
        return ConjugateGradient.solve(x0, M, ctx.tol, ctx.max_iter)
    @staticmethod
    def backward(ctx, grad_x):
        """
        Backward pass of the conjugate gradient solver.

        Parameters
        ----------
        ctx: Context object.
        grad_x: Gradient of the output image.

        Returns
        -------
        grad_z: Gradient of the input image.
        """
        ATy, rhs, smaps, mask, lambdaa = ctx.saved_tensors

        def A(x):
            x = fft2c(complex_mul(x.expand_as(smaps), smaps), fft_type=ctx.fft_type) * mask
            return torch.sum(x, dim=-4, keepdim=True)

        def AT(x):
            return torch.sum(complex_mul(ifft2c(x * mask, fft_type=ctx.fft_type), complex_conj(smaps)), dim=-5)

        def M(p):
            return lambdaa * AT(A(p)) + p

        Qe = ConjugateGradient.solve(grad_x, M, ctx.tol, ctx.max_iter)
        QQe = ConjugateGradient.solve(Qe, M, ctx.tol, ctx.max_iter)

        grad_z = Qe
        grad_lambdaa = (
            complex_mul(ifft2c(Qe, fft_type=ctx.fft_type), complex_conj(ATy)).sum()
            - complex_mul(ifft2c(QQe, fft_type=ctx.fft_type), complex_conj(rhs)).sum()
        )

        return grad_z, grad_lambdaa, None, None, None, None, None, None
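# Note (added for clarity): together, forward and solve compute the proximal
# mapping of the data term, i.e. they solve the regularized normal equations
#     (lambda * A^H A + I) x = lambda * A^H y + z,
# which is the minimizer of lambda/2 * ||A x - y||_2^2 + 1/2 * ||x - z||_2^2,
# where A is the masked SENSE forward operator built from smaps and mask.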
class DataVSLayer(torch.nn.Module):
    """DataLayer using a variable splitting formulation."""

    def __init__(self, alpha_init, beta_init, learnable=True, fft_type="orthogonal"):
        """
        Parameters
        ----------
        alpha_init: Init value of the data consistency block (DCB) weight.
        beta_init: Init value of the weighted averaging block (WAB) weight.
        learnable: If True, the parameters of the model are learnable.
        fft_type: Type of FFT to use. Can be "orthogonal".
        """
        super().__init__()
        self.alpha = torch.nn.Parameter(torch.Tensor(1))
        self.alpha.data = torch.tensor(alpha_init, dtype=self.alpha.dtype)
        self.beta = torch.nn.Parameter(torch.Tensor(1))
        self.beta.data = torch.tensor(beta_init, dtype=self.beta.dtype)
        self.learnable = learnable
        self.set_learnable(learnable)
        self.fft_type = fft_type
    def forward(self, x, y, smaps, mask):
        """
        Forward pass of the data-consistency block.

        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        smaps: Coil sensitivity maps.
        mask: Sampling mask.

        Returns
        -------
        Output image.
        """
        A_x = torch.sum(
            fft2c(complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps), fft_type=self.fft_type), -4, keepdim=True
        )
        # Soft data consistency in k-space: keep unsampled entries, blend sampled ones with y.
        k_dc = (1 - mask) * A_x + mask * (self.alpha * A_x + (1 - self.alpha) * y)
        x_dc = torch.sum(complex_mul(ifft2c(k_dc, fft_type=self.fft_type), complex_conj(smaps)), dim=-5)
        # Weighted averaging block: blend the input image with the data-consistent image.
        return self.beta * x + (1 - self.beta) * x_dc
    def set_learnable(self, flag):
        """
        Set the learnable flag of the parameters.

        Parameters
        ----------
        flag: If True, the parameters of the model are learnable.
        """
        self.learnable = flag
        self.alpha.requires_grad = self.learnable
        self.beta.requires_grad = self.learnable
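# Usage sketch (added for illustration; not part of the original module; shapes
# are the same assumptions as in the examples above):
#
# >>> layer = DataVSLayer(alpha_init=0.5, beta_init=0.5)
# >>> x = torch.randn(1, 1, 32, 32, 2)
# >>> smaps = torch.randn(1, 4, 1, 32, 32, 2)
# >>> y = torch.randn(1, 4, 1, 32, 32, 2)
# >>> mask = torch.randint(0, 2, (1, 1, 1, 1, 32, 1)).float()
# >>> x_out = layer(x, y, smaps, mask)                         # same shape as x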
class DCLayer(torch.nn.Module):
    """Data Consistency layer from DC-CNN, applied mainly to single-coil data."""

    def __init__(self, lambda_init=0.0, learnable=True, fft_type="orthogonal"):
        """
        Parameters
        ----------
        lambda_init: Init value of the data consistency block (DCB) weight.
        learnable: If True, the parameters of the model are learnable.
        fft_type: Type of FFT to use. Can be "orthogonal".
        """
        super().__init__()
        self.lambda_ = torch.nn.Parameter(torch.Tensor(1))
        self.lambda_.data = torch.tensor(lambda_init, dtype=self.lambda_.dtype)
        self.learnable = learnable
        self.set_learnable(learnable)
        self.fft_type = fft_type
    def forward(self, x, y, mask):
        """
        Forward pass of the data-consistency block.

        Parameters
        ----------
        x: Input image.
        y: Subsampled k-space data.
        mask: Sampling mask.

        Returns
        -------
        Output image.
        """
        A_x = fft2c(x, fft_type=self.fft_type)
        # Soft data consistency in k-space: keep unsampled entries, blend sampled ones with y.
        k_dc = (1 - mask) * A_x + mask * (self.lambda_ * A_x + (1 - self.lambda_) * y)
        return ifft2c(k_dc, fft_type=self.fft_type)
    def set_learnable(self, flag):
        """
        Set the learnable flag of the parameters.

        Parameters
        ----------
        flag: If True, the parameters of the model are learnable.
        """
        self.learnable = flag
        self.lambda_.requires_grad = self.learnable
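# Usage sketch (added for illustration; not part of the original module). For the
# single-coil DCLayer the assumed shapes reduce to (batch, height, width, 2):
#
# >>> layer = DCLayer(lambda_init=0.5)
# >>> x = torch.randn(1, 32, 32, 2)                            # single-coil image
# >>> y = torch.randn(1, 32, 32, 2)                            # single-coil k-space
# >>> mask = torch.randint(0, 2, (1, 1, 32, 1)).float()        # sampling mask
# >>> x_out = layer(x, y, mask)                                # same shape as x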