Source code for mridc.collections.reconstruction.models.sigmanet.sensitivity_net

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from:
# https://github.com/khammernik/sigmanet/blob/master/reconstruction/common/mytorch/models/sn.py
import numpy as np
import torch


[docs]def matrix_invert(xx, xy, yx, yy): """Invert a 2x2 matrix.""" det = xx * yy - xy * yx return yy.div(det), -xy.div(det), -yx.div(det), xx.div(det)
[docs]class ComplexInstanceNorm(torch.nn.Module): """Motivated by 'Deep Complex Networks' (https://arxiv.org/pdf/1705.09792.pdf)""" def __init__(self): super(ComplexInstanceNorm, self).__init__() self.mean = 0 self.cov_xx_half = 1 / np.sqrt(2) self.cov_xy_half = 0 self.cov_yx_half = 0 self.cov_yy_half = 1 / np.sqrt(2)
[docs] def complex_instance_norm(self, x, eps=1e-5): """Operates on images x of size [nBatch, nSmaps, nFE, nPE, 2]""" x_combined = torch.sum(x, dim=1, keepdim=True) mean = x_combined.mean(dim=(1, 2, 3), keepdim=True) x_m = x - mean self.mean = mean self.complex_pseudocovariance(x_m)
[docs] def complex_pseudocovariance(self, data): """Data variable hast to be already mean-free! Operates on images x of size [nBatch, nSmaps, nFE, nPE, 2]""" if data.size(-1) != 2: raise AssertionError shape = data.shape # compute number of elements N = shape[2] * shape[3] # separate real/imaginary channel re, im = torch.unbind(data, dim=-1) # dimensions is now length of original shape - 1 (because channels are seperated) dim = list(range(1, len(shape) - 1)) # compute covariance entries. cxy = cyx cxx = (re * re).sum(dim=dim, keepdim=True) / (N - 1) cyy = (im * im).sum(dim=dim, keepdim=True) / (N - 1) cxy = (re * im).sum(dim=dim, keepdim=True) / (N - 1) # Eigenvalue decomposition C = V*S*inv(V) # compute eigenvalues s1 = (cxx + cyy) / 2 - torch.sqrt((cxx + cyy) ** 2 / 4 - cxx * cyy + cxy**2) s2 = (cxx + cyy) / 2 + torch.sqrt((cxx + cyy) ** 2 / 4 - cxx * cyy + cxy**2) # compute eigenvectors v1x = s1 - cyy v1y = cxy v2x = s2 - cyy v2y = cxy # normalize eigenvectors norm1 = torch.sqrt(torch.sum(v1x * v1x + v1y * v1y, dim=dim, keepdim=True)) norm2 = torch.sqrt(torch.sum(v2x * v2x + v2y * v2y, dim=dim, keepdim=True)) v1x = v1x.div(norm1) v1y = v1y.div(norm1) v2x = v2x.div(norm2) v2y = v2y.div(norm2) # now we need the sqrt of the covariance matrix. # C^{-0.5} = V * sqrt(S) * inv(V) det = v1x * v2y - v2x * v1y s1 = torch.sqrt(s1).div(det) s2 = torch.sqrt(s2).div(det) self.cov_xx_half = v1x * v2y * s1 - v1y * v2x * s2 self.cov_yy_half = v1x * v2y * s2 - v1y * v2x * s1 self.cov_xy_half = v1x * v2x * (s2 - s1) self.cov_yx_half = v1y * v2y * (s1 - s2)
[docs] def forward(self, input): """Operates on images x of size [nBatch, nSmaps, nFE, nPE, 2]""" return self.normalize(input)
[docs] def set_normalization(self, input): """Set the normalization parameters for a given input.""" mean = torch.tensor([torch.mean(input).item()]).to(input) self.complex_pseudocovariance(input - mean) self.mean = mean.unsqueeze(1).unsqueeze(1).unsqueeze(1) self.cov_xx_half = self.cov_xx_half.view(-1, 1, 1, 1) self.cov_xy_half = self.cov_xy_half.view(-1, 1, 1, 1) self.cov_yx_half = self.cov_yx_half.view(-1, 1, 1, 1) self.cov_yy_half = self.cov_yy_half.view(-1, 1, 1, 1)
[docs] def normalize(self, x): """Normalize the input x.""" x_m = x - self.mean re, im = torch.unbind(x_m, dim=-1) cov_xx_half_inv, cov_xy_half_inv, cov_yx_half_inv, cov_yy_half_inv = matrix_invert( self.cov_xx_half, self.cov_xy_half, self.cov_yx_half, self.cov_yy_half ) x_norm_re = cov_xx_half_inv * re + cov_xy_half_inv * im x_norm_im = cov_yx_half_inv * re + cov_yy_half_inv * im img = torch.stack([x_norm_re, x_norm_im], dim=-1) img = img.clamp(-6, 6) return img
[docs] def unnormalize(self, x): """Unnormalize the input x.""" re, im = torch.unbind(x, dim=-1) x_unnorm_re = self.cov_xx_half * re + self.cov_xy_half * im x_unnorm_im = self.cov_yx_half * re + self.cov_yy_half * im return torch.stack([x_unnorm_re, x_unnorm_im], dim=-1) + self.mean
[docs]class ComplexNormWrapper(torch.nn.Module): """Wrapper for complex normalization.""" def __init__(self, model): super().__init__() self.model = model self.complex_instance_norm = ComplexInstanceNorm()
[docs] def forward(self, input): # compute complex instance norm on sample of size [nBatch, nSmaps, nFE, nPE, 2] self.complex_instance_norm.set_normalization(input) output = self.complex_instance_norm.normalize(input) # re-shape data from [nBatch, nSmaps, nFE, nPE, 2] to [nBatch*nSmaps, 2, nFE, nPE] shp = output.shape output = output.view(shp[0] * shp[1], *shp[2:]).permute(0, 3, 1, 2) # apply denoising output = self.model(output) # re-shape data from [nBatch*nSmaps, 2, nFE, nPE] # to [nBatch, nSmaps, nFE, nPE, 2] output = output.permute(0, 2, 3, 1).view(*shp) # unnormalize output = self.complex_instance_norm.unnormalize(output) return output
[docs]class SensitivityNetwork(torch.nn.Module): """Sensitivity network with data term based on forward and adjoint containing the sensitivity maps""" def __init__( self, num_iter, model, datalayer, shared_params=True, save_space=False, reset_cache=False, ): """ Parameters ---------- num_iter: Number of iterations. model: Model to be used for the forward and adjoint. datalayer: Data layer to be used for the forward and adjoint. shared_params: If True, the parameters of the model are shared between the forward and adjoint. save_space: If True, the adjoint is computed in the forward pass. reset_cache: If True, the adjoint is computed in the forward pass. """ super().__init__() self.shared_params = shared_params self.num_iter = 1 if self.shared_params else num_iter self.num_iter_total = num_iter self.is_trainable = [True] * num_iter # setup the modules self.gradR = torch.nn.ModuleList([ComplexNormWrapper(model) for _ in range(self.num_iter)]) self.gradD = torch.nn.ModuleList([datalayer for _ in range(self.num_iter)]) self.save_space = save_space if self.save_space: self.forward = self.forward_save_space self.reset_cache = reset_cache
[docs] def forward(self, x, y, smaps, mask): """ Parameters ---------- x: Input data. y: Subsampled k-space data. smaps: Coil sensitivity maps. mask: Sampling mask. Returns ------- Output data. """ x_all = [x] x_half_all = [] if self.shared_params: num_iter = self.num_iter_total else: num_iter = min(np.where(self.is_trainable)[0][-1] + 1, self.num_iter) for i in range(num_iter): x_thalf = x - self.gradR[i % self.num_iter](x) x = self.gradD[i % self.num_iter](x_thalf, y, smaps, mask) x_all.append(x) x_half_all.append(x_thalf) return x_all[-1]
[docs] def forward_save_space(self, x, y, smaps, mask): """ Parameters ---------- x: Input data. y: Subsampled k-space data. smaps: Coil sensitivity maps. mask: Sampling mask. Returns ------- Output data. """ if self.shared_params: num_iter = self.num_iter_total else: num_iter = min(np.where(self.is_trainable)[0][-1] + 1, self.num_iter) for i in range(num_iter): x_thalf = x - self.gradR[i % self.num_iter](x) x = self.gradD[i % self.num_iter](x_thalf, y, smaps, mask) # would run out of memory at test time # if this is False for some cases if self.reset_cache: torch.cuda.empty_cache() torch.backends.cuda.cufft_plan_cache.clear() return x
[docs] def freeze(self, i): """freeze parameter of cascade i""" for param in self.gradR[i].parameters(): param.require_grad_ = False self.is_trainable[i] = False
[docs] def unfreeze(self, i): """freeze parameter of cascade i""" for param in self.gradR[i].parameters(): param.require_grad_ = True self.is_trainable[i] = True
[docs] def freeze_all(self): """freeze parameter of cascade i""" for i in range(self.num_iter): self.freeze(i)
[docs] def unfreeze_all(self): """freeze parameter of cascade i""" for i in range(self.num_iter): self.unfreeze(i)
[docs] def copy_params(self, src_i, trg_j): """copy i-th cascade net parameters to j-th cascade net parameters""" src_params = self.gradR[src_i].parameters() trg_params = self.gradR[trg_j].parameters() for trg_param, src_param in zip(trg_params, src_params): trg_param.data.copy_(src_param.data)
[docs] def stage_training_init(self): """set stage training flag to True""" self.freeze_all() self.unfreeze(0) print(self.is_trainable)
[docs] def stage_training_transition_i(self, copy=False): """set stage training flag to True""" if self.shared_params: return # if all unlocked, don't do anything if not np.all(self.is_trainable): for i in range(self.num_iter): # if last cascade is reached, unlock all if i == self.num_iter - 1: self.unfreeze_all() break # freeze current i, unlock next. copy parameter if specified if self.is_trainable[i]: self.freeze(i) self.unfreeze(i + 1) if copy: self.copy_params(i, i + 1) break