162 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			162 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # From https://github.com/Koushik0901/Swift-SRGAN/blob/master/swift-srgan/models.py
 | |
| 
 | |
| import torch
 | |
| from torch import nn
 | |
| 
 | |
| 
 | |
class SeperableConv2d(nn.Module):
    """Depthwise-separable 2D convolution.

    The input is first filtered channel-by-channel (``depthwise``) and the
    result is then mixed across channels by a 1x1 convolution
    (``pointwise``) that maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the pointwise conv.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        padding: padding of the depthwise convolution.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True
    ):
        super().__init__()
        # groups == in_channels gives every input channel its own filter.
        depthwise_options = {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "groups": in_channels,
            "bias": bias,
        }
        self.depthwise = nn.Conv2d(in_channels, in_channels, **depthwise_options)
        # 1x1 convolution: only mixes channels, never touches spatial size.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise one."""
        per_channel = self.depthwise(x)
        return self.pointwise(per_channel)
 | |
| 
 | |
| 
 | |
class ConvBlock(nn.Module):
    """Separable convolution followed by optional batch norm and activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        use_act: when True, ``forward`` applies the activation after the norm.
        use_bn: when True, a ``BatchNorm2d`` follows the convolution and the
            convolution is built without bias; otherwise the norm is an
            ``Identity`` and the convolution keeps its bias.
        discriminator: selects ``LeakyReLU(0.2)`` instead of a per-channel
            ``PReLU`` as the activation.
        **kwargs: forwarded to ``SeperableConv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        use_act=True,
        use_bn=True,
        discriminator=False,
        **kwargs,
    ):
        super().__init__()

        self.use_act = use_act
        self.cnn = SeperableConv2d(in_channels, out_channels, **kwargs, bias=not use_bn)

        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()

        # The activation is registered regardless of ``use_act``, so the set
        # of parameters of the block does not depend on that flag.
        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        """Run conv -> norm, then the activation if ``use_act`` is set."""
        normed = self.bn(self.cnn(x))
        if not self.use_act:
            return normed
        return self.act(normed)
 | |
| 
 | |
| 
 | |
class UpsampleBlock(nn.Module):
    """Upsample by ``scale_factor`` using a separable conv and pixel shuffle.

    Args:
        in_channels: number of channels of the input (and of the output).
        scale_factor: spatial upscaling factor applied by the pixel shuffle.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()

        # The convolution widens the channels by scale_factor**2 so the pixel
        # shuffle can trade them for spatial resolution.
        expanded_channels = in_channels * scale_factor**2
        self.conv = SeperableConv2d(
            in_channels, expanded_channels, kernel_size=3, stride=1, padding=1
        )
        # (in_channels * r**2, H, W) -> (in_channels, H*r, W*r), r = scale_factor
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_channels)

    def forward(self, x):
        """Apply conv -> pixel shuffle -> PReLU."""
        widened = self.conv(x)
        shuffled = self.ps(widened)
        return self.act(shuffled)
 | |
| 
 | |
| 
 | |
class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers wrapped in an identity skip connection.

    Args:
        in_channels: number of channels; the block preserves it.
    """

    def __init__(self, in_channels):
        super().__init__()

        conv_options = {"kernel_size": 3, "stride": 1, "padding": 1}
        self.block1 = ConvBlock(in_channels, in_channels, **conv_options)
        # The second block skips its activation before the skip addition.
        self.block2 = ConvBlock(
            in_channels, in_channels, use_act=False, **conv_options
        )

    def forward(self, x):
        """Return ``block2(block1(x)) + x``."""
        residual = self.block2(self.block1(x))
        return residual + x
 | |
| 
 | |
| 
 | |
class Generator(nn.Module):
    """Swift-SRGAN generator whose architecture is inferred from a checkpoint.

    The constructor does not take hyper-parameters. Channel counts, the
    number of residual blocks and the number of upsampling stages are all
    read from the tensor shapes and key names of ``state_dict``, the network
    is built to match, and the weights are then loaded into it.

    Args:
        state_dict: mapping from parameter names to tensors. If it contains a
            ``"model"`` entry, that entry is used as the actual state dict.

    Raises:
        KeyError: if one of the keys used for shape detection
            (``initial.cnn.depthwise.weight``, ``initial.cnn.pointwise.weight``,
            ``final_conv.pointwise.weight``) is missing.
    """

    def __init__(
        self,
        state_dict,
    ):
        super().__init__()
        self.model_arch = "Swift-SRGAN"
        self.sub_type = "SR"
        self.state = state_dict
        if "model" in self.state:
            self.state = self.state["model"]

        # Depthwise weights are (in_channels, 1, k, k); pointwise weights are
        # (out_channels, in_channels, 1, 1), so dim 0 gives the widths we need.
        self.in_nc: int = self.state["initial.cnn.depthwise.weight"].shape[0]
        self.out_nc: int = self.state["final_conv.pointwise.weight"].shape[0]
        self.num_filters: int = self.state["initial.cnn.pointwise.weight"].shape[0]
        self.num_blocks = self._count_children("residual")
        # Every upsample block doubles the resolution.
        num_upsamplers = self._count_children("upsampler")
        self.scale: int = 2**num_upsamplers

        in_channels = self.in_nc
        num_channels = self.num_filters
        num_blocks = self.num_blocks

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None

        self.initial = ConvBlock(
            in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
        )
        self.residual = nn.Sequential(
            *[ResidualBlock(num_channels) for _ in range(num_blocks)]
        )
        self.convblock = ConvBlock(
            num_channels,
            num_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False,
        )
        # Build exactly as many 2x stages as the checkpoint contains. The
        # previous ``scale // 2`` only equals that count for 2x and 4x models
        # and over-allocates (with untrained weights) from 8x upwards.
        self.upsampler = nn.Sequential(
            *[
                UpsampleBlock(num_channels, scale_factor=2)
                for _ in range(num_upsamplers)
            ]
        )
        # Use the detected output width so the layer matches the checkpoint
        # even when the output channel count differs from the input's.
        self.final_conv = SeperableConv2d(
            num_channels, self.out_nc, kernel_size=9, stride=1, padding=4
        )

        # NOTE: strict=False means missing or unexpected keys are ignored
        # rather than raising.
        self.load_state_dict(self.state, strict=False)

    def _count_children(self, prefix):
        """Count distinct ``<prefix>.<index>`` sub-modules named in the state dict."""
        marker = prefix + "."
        return len(
            {key.split(".")[1] for key in self.state if key.startswith(marker)}
        )

    def forward(self, x):
        """Super-resolve ``x``; the result is ``(tanh(y) + 1) / 2``, i.e. in [0, 1]."""
        initial = self.initial(x)
        x = self.residual(initial)
        # Long skip connection around the whole residual trunk.
        x = self.convblock(x) + initial
        x = self.upsampler(x)
        return (torch.tanh(self.final_conv(x)) + 1) / 2
 |