# Source code for mridc.collections.reconstruction.models.primaldual.pd

# coding=utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/lpd/lpd.py
# Copyright (c) DIRECT Contributors

import torch
import torch.nn as nn


class DualNet(nn.Module):
    """Dual-update block of the Learned Primal Dual (LPD) network."""

    def __init__(self, num_dual, **kwargs):
        """
        Build the dual block.

        Parameters
        ----------
        num_dual:
            Number of dual for the LPD algorithm; determines the input and output channel counts of the default
            convolutional block.
        kwargs:
            Keyword arguments. Keys read here:

            dual_architecture:
                Module to use as the dual block. When present and not None it is stored unchanged and ``n_hidden``
                is not read.
            n_hidden:
                Number of hidden channels of the default convolutional block.

        Raises
        ------
        ValueError
            If ``dual_architecture`` is absent/None and ``n_hidden`` is absent/None.
        """
        super().__init__()

        custom_block = kwargs.get("dual_architecture")
        if custom_block is not None:
            # A user-supplied architecture takes precedence over the default stack.
            self.dual_block = custom_block
            return

        hidden_channels = kwargs.get("n_hidden")
        if hidden_channels is None:
            raise ValueError("n_hidden is required for DualNet")

        # forward() concatenates three tensors on the channel axis; the default block expects
        # 2 * num_dual + 2 + 2 input channels in total and emits 2 * num_dual channels.
        in_channels = 2 * (num_dual + 2)
        out_channels = 2 * num_dual

        # Layers are created in input-to-output order so seeded initialisation is reproducible.
        layers = (
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1),
        )
        self.dual_block = nn.Sequential(*layers)

    @staticmethod
    def compute_model_per_coil(model, data):
        """
        Apply ``model`` independently to every slice of ``data`` along dimension 1.

        Parameters
        ----------
        model:
            Callable applied to each slice.
        data:
            Multi-coil input; dimension 1 is the one iterated over (the coil axis, per the original docstring).

        Returns
        -------
        The per-slice outputs stacked back along dimension 1.
        """
        per_coil_outputs = [model(data.select(1, coil)) for coil in range(data.size(1))]
        return torch.stack(per_coil_outputs, dim=1)

    def forward(self, h, forward_f, g):
        """
        Forward pass of the dual network.

        Parameters
        ----------
        h:
            First tensor of the concatenation.
        forward_f:
            Second tensor of the concatenation.
        g:
            Third tensor of the concatenation.

        Returns
        -------
        Output of the dual block with the channel axis moved back to the last position.

        Notes
        -----
        The three inputs are joined along their last axis and the result must be 5-D. The layout is presumably
        (batch, coil, height, width, channels) -- the code only fixes the rank; confirm against callers.
        """
        stacked = torch.cat([h, forward_f, g], dim=-1)
        # Move the last axis to position 2 so each slice along axis 1 is channel-first for the block.
        channels_first = stacked.permute(0, 1, 4, 2, 3)
        updated = self.compute_model_per_coil(self.dual_block, channels_first)
        # Undo the permutation: channel axis back to the end.
        return updated.permute(0, 1, 3, 4, 2)
class PrimalNet(nn.Module):
    """Primal-update block of the Learned Primal Dual (LPD) network."""

    def __init__(self, num_primal, **kwargs):
        """
        Build the primal block.

        Parameters
        ----------
        num_primal:
            Number of primal for the LPD algorithm; determines the input and output channel counts of the default
            convolutional block.
        kwargs:
            Keyword arguments. Keys read here:

            primal_architecture:
                Module to use as the primal block. When present and not None it is stored unchanged and
                ``n_hidden`` is not read.
            n_hidden:
                Number of hidden channels of the default convolutional block.

        Raises
        ------
        ValueError
            If ``primal_architecture`` is absent/None and ``n_hidden`` is absent/None.
        """
        super().__init__()

        custom_block = kwargs.get("primal_architecture")
        if custom_block is not None:
            # A user-supplied architecture takes precedence over the default stack.
            self.primal_block = custom_block
            return

        hidden_channels = kwargs.get("n_hidden")
        if hidden_channels is None:
            raise ValueError("Missing argument n_hidden.")

        # forward() concatenates two tensors on the channel axis; the default block expects
        # 2 * num_primal + 2 input channels in total and emits 2 * num_primal channels.
        in_channels = 2 * (num_primal + 1)
        out_channels = 2 * num_primal

        # Layers are created in input-to-output order so seeded initialisation is reproducible.
        layers = (
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1),
        )
        self.primal_block = nn.Sequential(*layers)

    def forward(self, f, backward_h):
        """
        Forward pass of the primal network.

        Parameters
        ----------
        f:
            First tensor of the concatenation.
        backward_h:
            Second tensor of the concatenation.

        Returns
        -------
        Output of the primal block with the channel axis moved back to the last position.

        Notes
        -----
        The two inputs are joined along their last axis and the result must be 4-D. The layout is presumably
        (batch, height, width, channels) -- the code only fixes the rank; confirm against callers.
        """
        stacked = torch.cat([f, backward_h], dim=-1)
        # Channel-last -> channel-first for the block.
        channels_first = stacked.permute(0, 3, 1, 2)
        updated = self.primal_block(channels_first)
        # Channel-first -> channel-last for the caller.
        return updated.permute(0, 2, 3, 1)