# coding=utf-8
__author__ = "Dimitrios Karkalousos"
# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/nn/mwcnn/mwcnn.py
# Copyright (c) DIRECT Contributors
from collections import OrderedDict
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class DWT(nn.Module):
    """2D Discrete Wavelet Transform as implemented in Liu, Pengju, et al.

    Splits the input into four half-resolution sub-bands (LL, HL, LH, HH)
    concatenated along the channel dimension.

    References
    ----------
    .. Liu, Pengju, et al. "Multi-Level Wavelet-CNN for Image Restoration."
       ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
    """

    def __init__(self):
        """Inits DWT."""
        super().__init__()
        # Transform has no learnable parameters.
        self.requires_grad = False

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        """
        Computes DWT(`x`) given tensor `x`.

        Parameters
        ----------
        x: Input tensor of shape (N, C, H, W); H and W are assumed even.

        Returns
        -------
        DWT of `x`, shape (N, 4*C, H/2, W/2).
        """
        # Halve rows first (scaling by 1/2 keeps the Haar transform orthonormal
        # together with the matching 1/2 in IWT), then split columns.
        even_rows = x[:, :, 0::2, :] / 2
        odd_rows = x[:, :, 1::2, :] / 2
        ee = even_rows[:, :, :, 0::2]  # even row, even column
        oe = odd_rows[:, :, :, 0::2]   # odd row, even column
        eo = even_rows[:, :, :, 1::2]  # even row, odd column
        oo = odd_rows[:, :, :, 1::2]   # odd row, odd column
        x_LL = ee + oe + eo + oo
        x_HL = -ee - oe + eo + oo
        x_LH = -ee + oe - eo + oo
        x_HH = ee - oe - eo + oo
        return torch.cat((x_LL, x_HL, x_LH, x_HH), dim=1)
class IWT(nn.Module):
    """2D Inverse Wavelet Transform as implemented in Liu, Pengju, et al.

    Inverse of :class:`DWT`: recombines the four stacked sub-bands back into a
    double-resolution tensor.

    References
    ----------
    .. Liu, Pengju, et al. "Multi-Level Wavelet-CNN for Image Restoration."
       ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
    """

    def __init__(self):
        """Inits IWT."""
        super().__init__()
        # Transform has no learnable parameters.
        self.requires_grad = False
        self._r = 2  # upscaling factor per spatial dimension

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes IWT(`x`) given tensor `x`.

        Parameters
        ----------
        x: Input tensor of shape (N, 4*C, H, W) holding the LL/HL/LH/HH
            sub-bands stacked along the channel dimension.

        Returns
        -------
        IWT of `x`, shape (N, C, 2*H, 2*W).
        """
        batch, in_channel, in_height, in_width = x.size()
        out_channel = in_channel // (self._r**2)
        out_height, out_width = self._r * in_height, self._r * in_width
        x1 = x[:, 0:out_channel, :, :] / 2
        x2 = x[:, out_channel : out_channel * 2, :, :] / 2
        x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2
        x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2
        # Allocate directly on the input's device (avoids a CPU alloc + copy).
        h = torch.zeros([batch, out_channel, out_height, out_width], dtype=x.dtype, device=x.device)
        h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
        h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
        h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
        h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
        return h
class ConvBlock(nn.Module):
    """Convolution Block for MWCNN as implemented in Liu, Pengju, et al.

    Conv2d -> (optional BatchNorm2d) -> activation, with an output scale.

    References
    ----------
    .. Liu, Pengju, et al. "Multi-Level Wavelet-CNN for Image Restoration."
       ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        bias: bool = True,
        batchnorm: bool = False,
        activation: nn.Module = nn.ReLU(True),
        scale: Optional[float] = 1.0,
    ):
        """
        Inits ConvBlock.

        Parameters
        ----------
        in_channels: Number of input channels.
            int
        out_channels: Number of output channels.
            int
        kernel_size: Conv kernel size.
            int
        bias: Use convolution bias.
            bool, Default: True.
        batchnorm: Use batch normalization.
            bool, Default: False.
        activation: Activation function.
            torch.nn.Module, Default: nn.ReLU(True).
        scale: Scale factor for convolution.
            float (optional), Default: 1.0.
        """
        super().__init__()
        # "same" padding for odd kernel sizes keeps spatial dims unchanged.
        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                bias=bias,
                padding=kernel_size // 2,
            )
        ]
        if batchnorm:
            layers.append(nn.BatchNorm2d(num_features=out_channels, eps=1e-4, momentum=0.95))
        layers.append(activation)
        self.net = nn.Sequential(*layers)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass of ConvBlock.

        Parameters
        ----------
        x: Input with shape (N, C, H, W).

        Returns
        -------
        Output with shape (N, C', H', W').
        """
        return self.net(x) * self.scale
class DilatedConvBlock(nn.Module):
    """Double dilated Convolution Block for MWCNN as implemented in Liu, Pengju, et al.

    Two dilated Conv2d layers, each optionally followed by BatchNorm2d and
    always followed by the activation, with an output scale.

    References
    ----------
    .. Liu, Pengju, et al. "Multi-Level Wavelet-CNN for Image Restoration."
       ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
    """

    def __init__(
        self,
        in_channels: int,
        dilations: Tuple[int, int],
        kernel_size: int,
        out_channels: Optional[int] = None,
        bias: bool = True,
        batchnorm: bool = False,
        activation: nn.Module = nn.ReLU(True),
        scale: Optional[float] = 1.0,
    ):
        """
        Inits DilatedConvBlock.

        Parameters
        ----------
        in_channels: Number of input channels.
            int
        dilations: Dilation factors of the first and second convolution.
            Tuple[int, int]
        kernel_size: Conv kernel size.
            int
        out_channels: Number of output channels. If None, equals in_channels.
            int (optional), Default: None.
        bias: Use convolution bias.
            bool, Default: True.
        batchnorm: Use batch normalization.
            bool, Default: False.
        activation: Activation function.
            torch.nn.Module, Default: nn.ReLU(True).
        scale: Scale factor for convolution.
            float (optional), Default: 1.0.
        """
        super().__init__()
        # padding = k//2 + d - 1 keeps spatial dims unchanged for odd kernels.
        net = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=kernel_size,
                bias=bias,
                dilation=dilations[0],
                padding=kernel_size // 2 + dilations[0] - 1,
            )
        ]
        if batchnorm:
            net.append(nn.BatchNorm2d(num_features=in_channels, eps=1e-4, momentum=0.95))
        net.append(activation)
        if out_channels is None:
            out_channels = in_channels
        net.append(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                bias=bias,
                dilation=dilations[1],
                padding=kernel_size // 2 + dilations[1] - 1,
            )
        )
        if batchnorm:
            # BUGFIX: the second batchnorm must match the second conv's output
            # channels; using in_channels crashed when out_channels differed.
            net.append(nn.BatchNorm2d(num_features=out_channels, eps=1e-4, momentum=0.95))
        net.append(activation)
        self.net = nn.Sequential(*net)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass of DilatedConvBlock.

        Parameters
        ----------
        x: Input with shape (N, C, H, W).

        Returns
        -------
        Output with shape (N, C', H', W').
        """
        return self.net(x) * self.scale
class MWCNN(nn.Module):
    """Multi-level Wavelet CNN (MWCNN) implementation as implemented in Liu, Pengju, et al.

    Encoder-decoder: each encoder scale applies a ConvBlock + DilatedConvBlock
    then a DWT downsampling; each decoder scale applies the mirrored blocks and
    an IWT upsampling with additive skip connections.

    References
    ----------
    .. Liu, Pengju, et al. "Multi-Level Wavelet-CNN for Image Restoration."
       ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
    """

    def __init__(
        self,
        input_channels: int,
        first_conv_hidden_channels: int,
        num_scales: int = 4,
        bias: bool = True,
        batchnorm: bool = False,
        activation: nn.Module = nn.ReLU(True),
    ):
        """
        Inits MWCNN.

        Parameters
        ----------
        input_channels: Input channels dimension.
            int
        first_conv_hidden_channels: First convolution output channels dimension.
            int
        num_scales: Number of scales.
            int, Default: 4.
        bias: Convolution bias. If True, adds a learnable bias to the output.
            bool, Default: True.
        batchnorm: If True, a batchnorm layer is added after each convolution.
            bool, Default: False.
        activation: Activation function applied after each convolution.
            torch.nn.Module, Default: nn.ReLU().
        """
        super().__init__()
        self._kernel_size = 3
        self.DWT = DWT()
        self.IWT = IWT()

        # Encoder path. Channel count doubles per scale; each DWT multiplies
        # channels by 4, hence the 2**(idx + 1) input width after a DWT.
        self.down = nn.ModuleList()
        for idx in range(num_scales):
            in_channels = input_channels if idx == 0 else first_conv_hidden_channels * 2 ** (idx + 1)
            out_channels = first_conv_hidden_channels * 2**idx
            # The deepest scale uses a wider dilation pair.
            dilations = (2, 1) if idx != num_scales - 1 else (2, 3)
            self.down.append(
                nn.Sequential(
                    OrderedDict(
                        [
                            (
                                f"convblock{idx}",
                                ConvBlock(
                                    in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=self._kernel_size,
                                    bias=bias,
                                    batchnorm=batchnorm,
                                    activation=activation,
                                ),
                            ),
                            (
                                f"dilconvblock{idx}",
                                DilatedConvBlock(
                                    in_channels=out_channels,
                                    dilations=dilations,
                                    kernel_size=self._kernel_size,
                                    bias=bias,
                                    batchnorm=batchnorm,
                                    activation=activation,
                                ),
                            ),
                        ]
                    )
                )
            )

        # Decoder path, mirroring the encoder from deepest to shallowest scale.
        self.up = nn.ModuleList()
        for idx in reversed(range(num_scales)):
            in_channels = first_conv_hidden_channels * 2**idx
            out_channels = input_channels if idx == 0 else first_conv_hidden_channels * 2 ** (idx + 1)
            dilations = (2, 1) if idx != num_scales - 1 else (3, 2)
            self.up.append(
                nn.Sequential(
                    OrderedDict(
                        [
                            (
                                f"invdilconvblock{num_scales - 2 - idx}",
                                DilatedConvBlock(
                                    in_channels=in_channels,
                                    dilations=dilations,
                                    kernel_size=self._kernel_size,
                                    bias=bias,
                                    batchnorm=batchnorm,
                                    activation=activation,
                                ),
                            ),
                            (
                                f"invconvblock{num_scales - 2 - idx}",
                                ConvBlock(
                                    in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=self._kernel_size,
                                    bias=bias,
                                    batchnorm=batchnorm,
                                    activation=activation,
                                ),
                            ),
                        ]
                    )
                )
            )
        self.num_scales = num_scales

    @staticmethod
    def pad(x):
        """
        Pad the input with zeros so both spatial dims are even (required by DWT).

        Parameters
        ----------
        x: Input tensor.

        Returns
        -------
        Padded tensor.
        """
        # F.pad takes [left, right, top, bottom] for the last two dims.
        padding = [0, 0, 0, 0]
        if x.shape[-2] % 2 != 0:
            padding[3] = 1  # pad bottom so the height becomes even
        if x.shape[-1] % 2 != 0:
            padding[1] = 1  # pad right so the width becomes even
        if sum(padding) != 0:
            x = F.pad(x, padding, "reflect")
        return x

    @staticmethod
    def crop_to_shape(x, shape):
        """
        Crop the input to the given spatial shape (undoes `pad`).

        Parameters
        ----------
        x: Input tensor.
        shape: Tuple of (height, width).

        Returns
        -------
        Cropped tensor.
        """
        h, w = x.shape[-2:]
        if h > shape[0]:
            x = x[:, :, : shape[0], :]
        if w > shape[1]:
            x = x[:, :, :, : shape[1]]
        return x

    def forward(self, input_tensor: torch.Tensor, res: bool = False) -> torch.Tensor:
        """
        Computes forward pass of MWCNN.

        Parameters
        ----------
        input_tensor: Input tensor.
            torch.tensor
        res: If True, residual connection is applied to the output.
            bool, Default: False.

        Returns
        -------
        Output tensor.
        """
        res_values = []
        x = self.pad(input_tensor.clone())
        # Encoder: store each scale's activation for the decoder skips.
        for idx in range(self.num_scales):
            if idx == 0:
                x = self.pad(self.down[idx](x))
                res_values.append(x)
            elif idx == self.num_scales - 1:
                # Deepest scale: no skip stored, no padding needed.
                x = self.down[idx](self.DWT(x))
            else:
                x = self.pad(self.down[idx](self.DWT(x)))
                res_values.append(x)
        # Decoder: IWT upsampling, cropped to match the stored skip's size.
        for idx in range(self.num_scales):
            if idx != self.num_scales - 1:
                x = (
                    self.crop_to_shape(self.IWT(self.up[idx](x)), res_values[self.num_scales - 2 - idx].shape[-2:])
                    + res_values[self.num_scales - 2 - idx]
                )
            else:
                x = self.crop_to_shape(self.up[idx](x), input_tensor.shape[-2:])
        if res:
            x += input_tensor
        return x