547 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			547 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| # -*- coding: utf-8 -*-
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| from collections import OrderedDict
 | |
| try:
 | |
|     from typing import Literal
 | |
| except ImportError:
 | |
|     from typing_extensions import Literal
 | |
| 
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| 
 | |
| ####################
 | |
| # Basic blocks
 | |
| ####################
 | |
| 
 | |
| 
 | |
def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
    """Build an activation layer from its (case-insensitive) name.

    Args:
        act_type: One of "relu", "leakyrelu" or "prelu" (any letter case).
        inplace: Forwarded to ``nn.ReLU`` / ``nn.LeakyReLU``; unused for PReLU.
        neg_slope: Negative slope for LeakyReLU, and initial weight for PReLU.
        n_prelu: ``num_parameters`` for PReLU.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        NotImplementedError: If ``act_type`` names no supported activation.
    """
    factories = {
        "relu": lambda: nn.ReLU(inplace),
        "leakyrelu": lambda: nn.LeakyReLU(neg_slope, inplace),
        "prelu": lambda: nn.PReLU(num_parameters=n_prelu, init=neg_slope),
    }
    kind = act_type.lower()
    factory = factories.get(kind)
    if factory is None:
        raise NotImplementedError(
            "activation layer [{:s}] is not found".format(kind)
        )
    return factory()
 | |
| 
 | |
| 
 | |
def norm(norm_type: str, nc: int):
    """Build a 2-D normalization layer over ``nc`` channels.

    Args:
        norm_type: "batch" (affine BatchNorm2d) or "instance" (non-affine
            InstanceNorm2d), any letter case.
        nc: Number of channels to normalize.

    Raises:
        NotImplementedError: If ``norm_type`` is not recognised.
    """
    kind = norm_type.lower()
    if kind == "batch":
        return nn.BatchNorm2d(nc, affine=True)
    if kind == "instance":
        return nn.InstanceNorm2d(nc, affine=False)
    raise NotImplementedError(
        "normalization layer [{:s}] is not found".format(kind)
    )
 | |
| 
 | |
| 
 | |
def pad(pad_type: str, padding):
    """Build an explicit padding layer, or return None when no padding is needed.

    Zero padding is not handled here; ``conv_block`` passes it to ``nn.Conv2d``
    directly instead.

    Args:
        pad_type: "reflect" or "replicate" (any letter case).
        padding: Padding amount forwarded to the layer; 0 short-circuits to None.

    Raises:
        NotImplementedError: If ``padding`` is non-zero and ``pad_type`` is not
            one of the supported kinds.
    """
    kind = pad_type.lower()
    if padding == 0:
        return None
    layer_cls = {
        "reflect": nn.ReflectionPad2d,
        "replicate": nn.ReplicationPad2d,
    }.get(kind)
    if layer_cls is None:
        raise NotImplementedError(
            "padding layer [{:s}] is not implemented".format(kind)
        )
    return layer_cls(padding)
 | |
| 
 | |
| 
 | |
def get_valid_padding(kernel_size, dilation):
    """Return per-side padding ``(effective_kernel - 1) // 2`` for a dilated kernel.

    The effective extent of a kernel of size ``k`` with dilation ``d`` is
    ``d * (k - 1) + 1``. When that extent is odd and stride is 1, this padding
    keeps the spatial size unchanged.
    """
    effective_extent = dilation * (kernel_size - 1) + 1
    return (effective_extent - 1) // 2
 | |
| 
 | |
| 
 | |
class ConcatBlock(nn.Module):
    """Concatenate a submodule's output onto its input along dim 1.

    ``forward(x)`` returns ``torch.cat((x, sub(x)), dim=1)``.
    """

    def __init__(self, submodule):
        super().__init__()
        self.sub = submodule

    def forward(self, x):
        branch = self.sub(x)
        return torch.cat((x, branch), dim=1)

    def __repr__(self):
        # Prefix every line of the wrapped module's repr with "|" to show nesting.
        nested = repr(self.sub).replace("\n", "\n|")
        return "Identity .. \n|" + nested
 | |
| 
 | |
| 
 | |
class ShortcutBlock(nn.Module):
    """Residual wrapper: ``forward(x)`` returns ``x + sub(x)``."""

    def __init__(self, submodule):
        super().__init__()
        self.sub = submodule

    def forward(self, x):
        residual = self.sub(x)
        return x + residual

    def __repr__(self):
        # Prefix every line of the wrapped module's repr with "|" to show nesting.
        nested = repr(self.sub).replace("\n", "\n|")
        return "Identity + \n|" + nested
 | |
| 
 | |
| 
 | |
class ShortcutBlockSPSR(nn.Module):
    """Shortcut container used by SPSR-style models.

    Unlike ``ShortcutBlock``, ``forward`` neither applies nor adds the
    submodule. It returns the pair ``(x, submodule)`` and leaves it to the
    caller to combine them.
    """

    def __init__(self, submodule):
        super().__init__()
        self.sub = submodule

    def forward(self, x):
        pair = (x, self.sub)
        return pair

    def __repr__(self):
        # Prefix every line of the wrapped module's repr with "|" to show nesting.
        nested = repr(self.sub).replace("\n", "\n|")
        return "Identity + \n|" + nested
 | |
| 
 | |
| 
 | |
def sequential(*args):
    """Chain modules into a single flat ``nn.Sequential``.

    Given one argument, that argument is returned unchanged (an ``OrderedDict``
    is rejected). Given several, the direct children of any ``nn.Sequential``
    are spliced in, other ``nn.Module`` instances are appended as-is, and
    anything else (for example ``None`` placeholders) is skipped.

    Raises:
        NotImplementedError: If the single argument is an ``OrderedDict``.
    """
    if len(args) == 1:
        (only,) = args
        if isinstance(only, OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return only  # No sequential is needed.

    flat = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flat.extend(item.children())
        elif isinstance(item, nn.Module):
            flat.append(item)
    return nn.Sequential(*flat)
 | |
| 
 | |
| 
 | |
# Layer orderings accepted by conv_block:
#   "CNA"  -> Conv -> Norm -> Act
#   "NAC"  -> Norm -> Act -> Conv
#   "CNAC" -> built like "CNA" by conv_block; ResNetBlock drops the act and norm
#             of its second conv in this mode.
ConvMode = Literal[("CNA", "NAC", "CNAC")]
 | |
| 
 | |
| 
 | |
# 2x2x2 Conv Block
def conv_block_2c2(
    in_nc,
    out_nc,
    act_type="relu",
):
    """Two stacked 2x2 convolutions, optionally followed by an activation.

    The first conv pads by 1 and the second by 0. With the default stride of 1
    the spatial size goes H -> H + 1 -> H, so it is unchanged overall.
    ``conv_block`` delegates here when called with ``c2x2=True``.

    Args:
        in_nc: Input channels.
        out_nc: Output channels of both convolutions.
        act_type: Activation name passed to ``act``; falsy means no activation.
    """
    grow = nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1)
    trim = nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0)
    activation = act(act_type) if act_type else None
    return sequential(grow, trim, activation)
 | |
| 
 | |
| 
 | |
def conv_block(
    in_nc: int,
    out_nc: int,
    kernel_size,
    stride=1,
    dilation=1,
    groups=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type: str | None = "relu",
    mode: ConvMode = "CNA",
    c2x2=False,
):
    """
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
        CNAC --> built exactly like CNA by this function

    Args:
        in_nc: Input channels.
        out_nc: Output channels.
        kernel_size: Conv kernel size; padding is derived via ``get_valid_padding``.
        stride, dilation, groups, bias: Forwarded to ``nn.Conv2d``.
        pad_type: "zero" pads inside the conv itself. Any other truthy value is
            passed to ``pad()`` to build an explicit padding layer placed before
            the conv. A falsy value (e.g. None) disables padding entirely.
        norm_type: Name passed to ``norm``; falsy means no normalization. In
            CNA/CNAC it normalizes ``out_nc`` channels, in NAC ``in_nc``.
        act_type: Name passed to ``act``; falsy means no activation.
        mode: Layer ordering, one of "CNA", "NAC", "CNAC".
        c2x2: If True, every argument except ``in_nc``, ``out_nc`` and
            ``act_type`` is ignored and ``conv_block_2c2`` is returned instead.

    Returns:
        The assembled layers (via ``sequential``).

    Raises:
        AssertionError: If ``mode`` is invalid. It is raised explicitly rather
            than via ``assert`` so the check is not stripped under ``python -O``.
            Otherwise an invalid mode would silently return None.
    """
    if c2x2:
        return conv_block_2c2(in_nc, out_nc, act_type=act_type)

    if mode not in ("CNA", "NAC", "CNAC"):
        # Same exception type as the former `assert`, kept for compatibility.
        # The f-string also works when `mode` is not a str ("{:s}" did not).
        raise AssertionError(f"Wrong conv mode [{mode}]")

    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
    # Only "zero" padding is done by the conv; otherwise the explicit layer
    # above (or nothing, for a falsy pad_type) handles it.
    padding = padding if pad_type == "zero" else 0

    c = nn.Conv2d(
        in_nc,
        out_nc,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    a = act(act_type) if act_type else None

    if mode == "NAC":
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)

    # "CNA" and "CNAC"
    n = norm(norm_type, out_nc) if norm_type else None
    return sequential(p, c, n, a)
 | |
| 
 | |
| 
 | |
| ####################
 | |
| # Useful blocks
 | |
| ####################
 | |
| 
 | |
| 
 | |
class ResNetBlock(nn.Module):
    """
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)

    ``forward(x)`` returns ``x + res(x) * res_scale``, where ``res`` is two
    ``conv_block`` layers (in_nc -> mid_nc -> out_nc). The second conv gets no
    activation in "CNA" and "CNAC" modes, and no normalization in "CNAC" mode.

    No projection shortcut is applied, so the residual add is only valid when
    the two convs leave the input's shape unchanged.
    """

    def __init__(
        self,
        in_nc,
        mid_nc,
        out_nc,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_type="zero",
        norm_type=None,
        act_type="relu",
        mode: ConvMode = "CNA",
        res_scale=1,
    ):
        super().__init__()
        # Arguments shared verbatim by both convolutions.
        shared = dict(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            pad_type=pad_type,
            mode=mode,
        )
        first = conv_block(
            in_nc, mid_nc, norm_type=norm_type, act_type=act_type, **shared
        )

        # Residual path: |-CNAC-| drops both act and norm on the second conv;
        # CNA drops only the act.
        second_act = None if mode in ("CNA", "CNAC") else act_type
        second_norm = None if mode == "CNAC" else norm_type
        second = conv_block(
            mid_nc, out_nc, norm_type=second_norm, act_type=second_act, **shared
        )

        self.res = sequential(first, second)
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.res(x).mul(self.res_scale)
 | |
| 
 | |
| 
 | |
class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Three identically configured ``ResidualDenseBlock_5C`` modules applied in
    sequence. The output is ``RDB3(RDB2(RDB1(x))) * 0.2 + x``.

    ``_convtype`` and ``_spectral_norm`` are accepted for signature
    compatibility but are not used.
    """

    def __init__(
        self,
        nf,
        kernel_size=3,
        gc=32,
        stride=1,
        bias: bool = True,
        pad_type="zero",
        norm_type=None,
        act_type="leakyrelu",
        mode: ConvMode = "CNA",
        _convtype="Conv2D",
        _spectral_norm=False,
        plus=False,
        c2x2=False,
    ):
        super().__init__()

        def make_rdb():
            # Each call builds a fresh block with its own parameters.
            return ResidualDenseBlock_5C(
                nf=nf,
                kernel_size=kernel_size,
                gc=gc,
                stride=stride,
                bias=bias,
                pad_type=pad_type,
                norm_type=norm_type,
                act_type=act_type,
                mode=mode,
                plus=plus,
                c2x2=c2x2,
            )

        # Explicit attribute names: they become the state_dict key prefixes.
        self.RDB1 = make_rdb()
        self.RDB2 = make_rdb()
        self.RDB3 = make_rdb()

    def forward(self, x):
        out = x
        for rdb in (self.RDB1, self.RDB2, self.RDB3):
            out = rdb(out)
        return out * 0.2 + x
 | |
| 
 | |
| 
 | |
class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Optional modification supported here:
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}  (enabled with ``plus=True``)

    Each conv receives the channel-wise concatenation of the block input and
    all previous conv outputs. The block returns ``conv5(...) * 0.2 + x``.

    Args:
        nf (int): Channel number of intermediate features (num_feat). The
            block's input and output both have ``nf`` channels.
        kernel_size (int): Kernel size of conv1-conv4. conv5 always uses 3.
        gc (int): Channels for each growth (num_grow_ch: growth channel,
            i.e. intermediate channels).
        stride, bias, pad_type, norm_type, mode, c2x2: forwarded to every
            ``conv_block``.
        act_type (str | None): activation of conv1-conv4. conv5 gets no
            activation in "CNA" mode and ``act_type`` otherwise.
        plus (bool): enable the additional residual paths from ESRGAN+
            (adds trainable parameters: a bias-free 1x1 conv ``conv1x1``)
    """

    def __init__(
        self,
        nf=64,
        kernel_size=3,
        gc=32,
        stride=1,
        bias: bool = True,
        pad_type="zero",
        norm_type=None,
        act_type="leakyrelu",
        mode: ConvMode = "CNA",
        plus=False,
        c2x2=False,
    ):
        super().__init__()

        ## + (ESRGAN+): 1x1 projection of the input, added to x2 in forward.
        self.conv1x1 = conv1x1(nf, gc) if plus else None

        # Arguments identical for all five convolutions.
        shared = dict(
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            mode=mode,
            c2x2=c2x2,
        )
        # Construction order (conv1x1, conv1..conv5) is unchanged, so parameter
        # initialisation consumes the RNG in the same sequence as before.
        self.conv1 = conv_block(nf, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv2 = conv_block(nf + gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv3 = conv_block(nf + 2 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv4 = conv_block(nf + 3 * gc, gc, kernel_size, stride, act_type=act_type, **shared)

        # In "CNA" mode the last conv feeds the residual add without activation.
        last_act = None if mode == "CNA" else act_type
        # NOTE: kernel size 3 is hard-coded here (not `kernel_size`); kept so
        # existing checkpoints keep loading.
        self.conv5 = conv_block(nf + 4 * gc, nf, 3, stride, act_type=last_act, **shared)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1 is not None:
            # pylint: disable=not-callable
            x2 = x2 + self.conv1x1(x)  # +
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1 is not None:
            x4 = x4 + x2  # +
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
 | |
| 
 | |
| 
 | |
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` mapping ``in_planes`` to ``out_planes`` channels."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer
 | |
| 
 | |
| 
 | |
| ####################
 | |
| # Upsampler
 | |
| ####################
 | |
| 
 | |
| 
 | |
def pixelshuffle_block(
    in_nc: int,
    out_nc: int,
    upscale_factor=2,
    kernel_size=3,
    stride=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type="relu",
):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)

    The layers are built in this order:
      1. A ``conv_block`` (no norm, no activation) producing
         ``out_nc * upscale_factor**2`` channels.
      2. An ``nn.PixelShuffle(upscale_factor)``, which folds those channels
         back to ``out_nc``.
      3. Optionally, a norm over ``out_nc`` channels and an activation.
    """
    expanded_nc = out_nc * (upscale_factor**2)
    layers = [
        conv_block(
            in_nc,
            expanded_nc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=None,
            act_type=None,
        ),
        nn.PixelShuffle(upscale_factor),
    ]
    layers.append(norm(norm_type, out_nc) if norm_type else None)
    layers.append(act(act_type) if act_type else None)
    return sequential(*layers)
 | |
| 
 | |
| 
 | |
def upconv_block(
    in_nc: int,
    out_nc: int,
    upscale_factor=2,
    kernel_size=3,
    stride=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type="relu",
    mode="nearest",
    c2x2=False,
):
    """Upsample, then convolve (resize-convolution).

    Described in https://distill.pub/2016/deconv-checkerboard/.

    ``mode`` is the interpolation mode for ``nn.Upsample``, not a ``ConvMode``.
    The conv always uses ``conv_block``'s default layer ordering. The remaining
    arguments (including ``c2x2``) are forwarded to ``conv_block``.
    """
    return sequential(
        nn.Upsample(scale_factor=upscale_factor, mode=mode),
        conv_block(
            in_nc,
            out_nc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            c2x2=c2x2,
        ),
    )
 |