297 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			297 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| # -*- coding: utf-8 -*-
 | |
| 
 | |
| import functools
 | |
| import math
 | |
| import re
 | |
| from collections import OrderedDict
 | |
| 
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| import torch.nn.functional as F
 | |
| 
 | |
| from . import block as B
 | |
| 
 | |
| 
 | |
| # Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
 | |
| # which enhanced the code that was already here
 | |
class RRDBNet(nn.Module):
    """ESRGAN generator whose layout is inferred from a checkpoint's state dict."""

    def __init__(
        self,
        state_dict,
        norm=None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: "B.ConvMode" = "CNA",
    ) -> None:
        """
        ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
        By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
        and Chen Change Loy.
        This is old-arch Residual in Residual Dense Block Network and is not
        the newest revision that's available at github.com/xinntao/ESRGAN.
        This is on purpose, the newest Network has severely limited the
        potential use of the Network with no benefits.
        This network supports model files from both new and old-arch.

        The number of RRDB blocks, the input/output channel counts, the number
        of filters and the scale are all read from the keys and tensor shapes
        of ``state_dict``; the weights are then loaded into the built model.

        Args:
            state_dict: Mapping of parameter names to tensors. It may use
                old-arch keys (``model.N...``), new-arch keys
                (``conv_first``, ``RRDB_trunk``/``body``, ...) or hold the
                actual parameters under a ``"params_ema"`` entry.
            norm: Normalization layer
            act: Activation layer
            upsampler: Upsample layer. upconv, pixel_shuffle
            mode: Convolution mode

        Raises:
            NotImplementedError: If ``upsampler`` is neither ``"upconv"`` nor
                ``"pixel_shuffle"``.
            ValueError: If no RRDB block key can be found in ``state_dict``.
        """
        super().__init__()
        self.model_arch = "ESRGAN"
        self.sub_type = "SR"

        self.state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        # Maps old-arch key (or regex replacement template) -> candidate
        # new-arch keys (or regex patterns). "/NB/" is a placeholder that
        # new_to_old_arch() replaces with the detected number of blocks.
        self.state_map = {
            # currently supports old, new, and newer RRDBNet arch models
            # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
            ),
        }
        if "params_ema" in self.state:
            self.state = self.state["params_ema"]
            # self.model_arch = "RealESRGAN"
        self.num_blocks = self.get_num_blocks()
        self.plus = any("conv1x1" in k for k in self.state)
        if self.plus:
            self.model_arch = "ESRGAN+"

        self.state = self.new_to_old_arch(self.state)

        self.key_arr = list(self.state.keys())

        # The first and last entries of the (ordered) old-arch state are used
        # to read the channel counts of the network.
        first_weight = self.state[self.key_arr[0]]
        self.in_nc: int = first_weight.shape[1]
        self.out_nc: int = self.state[self.key_arr[-1]].shape[0]

        self.scale: int = self.get_scale()
        self.num_filters: int = first_weight.shape[0]

        # A first convolution with a kernel height of 2 marks the 2c2 variant.
        c2x2 = False
        if self.state["model.0.weight"].shape[-2] == 2:
            c2x2 = True
            self.scale = round(math.sqrt(self.scale / 4))
            self.model_arch = "ESRGAN-2c2"

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None

        # Detect if pixelunshuffle was used (Real-ESRGAN): the first conv then
        # takes 4x or 16x as many channels as the network outputs.
        if self.in_nc in (self.out_nc * 4, self.out_nc * 16):
            self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
        else:
            self.shuffle_factor = None

        upsample_block = {
            "upconv": B.upconv_block,
            "pixel_shuffle": B.pixelshuffle_block,
        }.get(self.upsampler)
        if upsample_block is None:
            raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")

        if self.scale == 3:
            upsample_blocks = upsample_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                upscale_factor=3,
                act_type=self.act,
                c2x2=c2x2,
            )
        else:
            # One upsample block per factor of two; log2 is used rather than
            # log(x, 2) so the count does not depend on float rounding.
            upsample_blocks = [
                upsample_block(
                    in_nc=self.num_filters,
                    out_nc=self.num_filters,
                    act_type=self.act,
                    c2x2=c2x2,
                )
                for _ in range(int(math.log2(self.scale)))
            ]

        self.model = B.sequential(
            # fea conv
            B.conv_block(
                in_nc=self.in_nc,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
            B.ShortcutBlock(
                B.sequential(
                    # rrdb blocks (always built with mode "CNA", independent
                    # of the ``mode`` argument)
                    *[
                        B.RRDB(
                            nf=self.num_filters,
                            kernel_size=3,
                            gc=32,
                            stride=1,
                            bias=True,
                            pad_type="zero",
                            norm_type=self.norm,
                            act_type=self.act,
                            mode="CNA",
                            plus=self.plus,
                            c2x2=c2x2,
                        )
                        for _ in range(self.num_blocks)
                    ],
                    # lr conv
                    B.conv_block(
                        in_nc=self.num_filters,
                        out_nc=self.num_filters,
                        kernel_size=3,
                        norm_type=self.norm,
                        act_type=None,
                        mode=self.mode,
                        c2x2=c2x2,
                    ),
                )
            ),
            *upsample_blocks,
            # hr_conv0
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=self.act,
                c2x2=c2x2,
            ),
            # hr_conv1
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.out_nc,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
        )

        # Adjust these properties for calculations outside of the model
        if self.shuffle_factor:
            self.in_nc //= self.shuffle_factor**2
            self.scale //= self.shuffle_factor

        # strict=False: keys that do not match the built model are ignored.
        self.load_state_dict(self.state, strict=False)

    def new_to_old_arch(self, state):
        """Convert a new-arch model state dictionary to an old-arch dictionary.

        Args:
            state: State dict, optionally wrapped under ``"params_ema"``.

        Returns:
            ``state`` itself when it has no ``"conv_first.weight"`` key (it is
            then assumed to be old-arch already); otherwise a new
            ``OrderedDict`` with old-arch keys, ordered by the layer index
            that follows ``"model."``.
        """
        if "params_ema" in state:
            state = state["params_ema"]

        if "conv_first.weight" not in state:
            # model is already old arch, this is a loose check, but should be sufficient
            return state

        # Replace the "/NB/" placeholder with the real block count. The guard
        # keeps a repeated call from failing once the placeholder is gone.
        for kind in ("weight", "bias"):
            placeholder = f"model.1.sub./NB/.{kind}"
            if placeholder in self.state_map:
                resolved = f"model.1.sub.{self.num_blocks}.{kind}"
                self.state_map[resolved] = self.state_map.pop(placeholder)

        old_state = OrderedDict()
        for old_key, new_keys in self.state_map.items():
            is_template = r"\1" in old_key
            for new_key in new_keys:
                if is_template:
                    # old_key is a replacement template, new_key a pattern;
                    # a key is converted only when the substitution changed it.
                    for k, v in state.items():
                        sub = re.sub(new_key, old_key, k)
                        if sub != k:
                            old_state[sub] = v
                elif new_key in state:
                    old_state[old_key] = state[new_key]

        # upconv layers: upconvN / conv_upN becomes model.(3 * N)
        upconv_pattern = re.compile(r"(upconv|conv_up)(\d)\.(weight|bias)")
        max_upconv = 0
        for key, value in state.items():
            match = upconv_pattern.match(key)
            if match is not None:
                _, key_num, key_type = match.groups()
                index = int(key_num) * 3
                old_state[f"model.{index}.{key_type}"] = value
                max_upconv = max(max_upconv, index)

        # final layers, placed relative to the last upconv layer
        final_layers = {
            "HRconv.weight": f"model.{max_upconv + 2}.weight",
            "conv_hr.weight": f"model.{max_upconv + 2}.weight",
            "HRconv.bias": f"model.{max_upconv + 2}.bias",
            "conv_hr.bias": f"model.{max_upconv + 2}.bias",
            "conv_last.weight": f"model.{max_upconv + 4}.weight",
            "conv_last.bias": f"model.{max_upconv + 4}.bias",
        }
        for key, value in state.items():
            target = final_layers.get(key)
            if target is not None:
                old_state[target] = value

        # Sort by first numeric value of each layer. sorted() is stable, so
        # keys sharing a layer index keep their insertion order.
        sorted_keys = sorted(old_state, key=lambda name: int(name.split(".")[1]))

        # Rebuild the output dict in the right order
        return OrderedDict((k, old_state[k]) for k in sorted_keys)

    def get_scale(self, min_part: int = 6) -> int:
        """Derive the scale from the old-arch keys in ``self.state``.

        Counts keys of the form ``<prefix>.<index>.weight`` whose index is
        greater than ``min_part`` and returns two to the power of that count.
        """
        n = 0
        for key in self.state:
            parts = key.split(".")[1:]
            if len(parts) == 2:
                part_num = int(parts[0])
                if part_num > min_part and parts[1] == "weight":
                    n += 1
        return 2**n

    def get_num_blocks(self) -> int:
        """Return the number of RRDB blocks found in ``self.state``.

        The new-arch patterns from ``self.state_map`` are tried first, then
        the old-arch pattern; the first pattern that matches any key decides.
        The result is the highest block index found plus one.

        Raises:
            ValueError: If no key in ``self.state`` matches any pattern.
        """
        patterns = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
        )
        for pattern in patterns:
            compiled = re.compile(pattern)
            nbs = [
                int(m.group(1))
                for m in (compiled.search(k) for k in self.state)
                if m
            ]
            if nbs:
                # max() takes the list itself: unpacking it (max(*nbs)) fails
                # when there is exactly one match.
                return max(nbs) + 1
        raise ValueError("No RRDB block keys were found in the state dict")

    def forward(self, x):
        """Run the network on ``x``.

        When a pixel-unshuffle factor was detected, the two trailing
        dimensions of ``x`` are reflect-padded up to a multiple of that
        factor, the input is pixel-unshuffled before the model, and the
        output is cropped to ``h * scale`` by ``w * scale``.
        """
        if not self.shuffle_factor:
            return self.model(x)

        factor = self.shuffle_factor
        _, _, h, w = x.size()
        mod_pad_h = (factor - h % factor) % factor
        mod_pad_w = (factor - w % factor) % factor
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
        x = torch.pixel_unshuffle(x, downscale_factor=factor)
        x = self.model(x)
        return x[:, :, : h * self.scale, : w * self.scale]
 |