# pylint: skip-file
"""
Model adapted from advimman's lama project: https://github.com/advimman/lama
"""

# Fast Fourier Convolution, NeurIPS 2020
# original implementation: https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
# paper: https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import InterpolationMode, rotate


class LearnableSpatialTransformWrapper(nn.Module):
    """Wraps a module so its input is reflection-padded and rotated by a
    (optionally learnable) angle before the call, and un-rotated/un-padded
    afterwards."""

    def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
        super().__init__()
        self.impl = impl
        self.angle = torch.rand(1) * angle_init_range
        if train_angle:
            self.angle = nn.Parameter(self.angle, requires_grad=True)
        self.pad_coef = pad_coef

    def forward(self, x):
        if torch.is_tensor(x):
            return self.inverse_transform(self.impl(self.transform(x)), x)
        elif isinstance(x, tuple):
            x_trans = tuple(self.transform(elem) for elem in x)
            y_trans = self.impl(x_trans)
            return tuple(
                self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)
            )
        else:
            raise ValueError(f"Unexpected input type {type(x)}")

    def transform(self, x):
        height, width = x.shape[2:]
        pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
        x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode="reflect")
        x_padded_rotated = rotate(
            x_padded, self.angle.to(x_padded), InterpolationMode.BILINEAR, fill=0
        )
        return x_padded_rotated

    def inverse_transform(self, y_padded_rotated, orig_x):
        height, width = orig_x.shape[2:]
        pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)

        y_padded = rotate(
            y_padded_rotated,
            -self.angle.to(y_padded_rotated),
            InterpolationMode.BILINEAR,
            fill=0,
        )
        y_height, y_width = y_padded.shape[2:]
        # Crop the padding back off to recover the original spatial size.
        y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
        return y


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        res = x * y.expand_as(x)
        return res


class FourierUnit(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        groups=1,
        spatial_scale_factor=None,
        spatial_scale_mode="bilinear",
        spectral_pos_encoding=False,
        use_se=False,
        se_kwargs=None,
        ffc3d=False,
        fft_norm="ortho",
    ):
        # bn_layer not used
        super().__init__()
        self.groups = groups

        self.conv_layer = torch.nn.Conv2d(
            in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
            out_channels=out_channels * 2,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=self.groups,
            bias=False,
        )
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

        # squeeze-and-excitation block
        self.use_se = use_se
        if use_se:
            if se_kwargs is None:
                se_kwargs = {}
            self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        # torch.fft does not support half precision, so half inputs (which in
        # practice only occur on GPU) are temporarily promoted to float32.
        half_check = x.type() == "torch.cuda.HalfTensor"

        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(
                x,
                scale_factor=self.spatial_scale_factor,
                mode=self.spatial_scale_mode,
                align_corners=False,
            )

        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        if half_check:
            ffted = torch.fft.rfftn(x.float(), dim=fft_dim, norm=self.fft_norm)
        else:
            ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)

        # Stack real and imaginary parts as extra channels:
        # (batch, c, h, w/2+1) complex -> (batch, c*2, h, w/2+1) real.
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1) + ffted.size()[3:])

        if self.spectral_pos_encoding:
            height, width = ffted.shape[-2:]
            coords_vert = (
                torch.linspace(0, 1, height)[None, None, :, None]
                .expand(batch, 1, height, width)
                .to(ffted)
            )
            coords_hor = (
                torch.linspace(0, 1, width)[None, None, None, :]
                .expand(batch, 1, height, width)
                .to(ffted)
            )
            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)

        if self.use_se:
            ffted = self.se(ffted)

        if half_check:
            ffted = self.conv_layer(ffted.half())  # (batch, c*2, h, w/2+1)
        else:
            ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)

        ffted = self.relu(self.bn(ffted))
        # force float so the inverse FFT below always sees a supported dtype
        ffted = ffted.float()

        ffted = (
            ffted.view((batch, -1, 2) + ffted.size()[2:])
            .permute(0, 1, 3, 4, 2)
            .contiguous()
        )  # (batch, c, t, h, w/2+1, 2)

        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(
            ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
        )

        if half_check:
            output = output.half()

        if self.spatial_scale_factor is not None:
            output = F.interpolate(
                output,
                size=orig_size,
                mode=self.spatial_scale_mode,
                align_corners=False,
            )

        return output


class SpectralTransform(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        groups=1,
        enable_lfu=True,
        separable_fu=False,  # unused here; kept for signature compatibility with upstream
        **fu_kwargs,
    ):
        # bn_layer not used
        super().__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
            ),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True),
        )
        fu_class = FourierUnit
        self.fu = fu_class(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
        if self.enable_lfu:
            self.lfu = fu_class(out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
        )

    def forward(self, x):
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)

        if self.enable_lfu:
            # Local Fourier Unit: split a quarter of the channels into a 2x2
            # grid of spatial patches, stack the patches along channels, run a
            # FourierUnit on them, then tile the result back to full size.
            _, c, h, _ = x.shape
            split_no = 2
            split_s = h // split_no
            xs = torch.cat(
                torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
            ).contiguous()
            xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
            xs = self.lfu(xs)
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()
        else:
            xs = 0

        output = self.conv2(x + output + xs)

        return output


class FFC(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio_gin,
        ratio_gout,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        enable_lfu=True,
        padding_type="reflect",
        gated=False,
        **spectral_kwargs,
    ):
        super().__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        # Split the channels between the local (spatial conv) and global
        # (spectral) branches according to the gin/gout ratios.
        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        # Each cross-branch path degenerates to Identity when either side is empty.
        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(
            in_cl,
            out_cl,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode=padding_type,
        )
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(
            in_cl,
            out_cg,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode=padding_type,
        )
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(
            in_cg,
            out_cl,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode=padding_type,
        )
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg,
            out_cg,
            stride,
            1 if groups == 1 else groups // 2,
            enable_lfu,
            **spectral_kwargs,
        )

        self.gated = gated
        module = (
            nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
        )
        self.gate = module(in_channels, 2, 1)

    def forward(self, x):
        x_l, x_g = x if isinstance(x, tuple) else (x, 0)
        out_xl, out_xg = 0, 0

        if self.gated:
            total_input_parts = [x_l]
            if torch.is_tensor(x_g):
                total_input_parts.append(x_g)
            total_input = torch.cat(total_input_parts, dim=1)

            gates = torch.sigmoid(self.gate(total_input))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate, l2g_gate = 1, 1

        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)

        return out_xl, out_xg


class FFC_BN_ACT(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio_gin,
        ratio_gout,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        norm_layer=nn.BatchNorm2d,
        activation_layer=nn.Identity,
        padding_type="reflect",
        enable_lfu=True,
        **kwargs,
    ):
        super().__init__()
        self.ffc = FFC(
            in_channels,
            out_channels,
            kernel_size,
            ratio_gin,
            ratio_gout,
            stride,
            padding,
            dilation,
            groups,
            bias,
            enable_lfu,
            padding_type=padding_type,
            **kwargs,
        )
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        global_channels = int(out_channels * ratio_gout)
        self.bn_l = lnorm(out_channels - global_channels)
        self.bn_g = gnorm(global_channels)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        x_l = self.act_l(self.bn_l(x_l))
        x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g


class FFCResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        padding_type,
        norm_layer,
        activation_layer=nn.ReLU,
        dilation=1,
        spatial_transform_kwargs=None,
        inline=False,
        **conv_kwargs,
    ):
        super().__init__()
        self.conv1 = FFC_BN_ACT(
            dim,
            dim,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            padding_type=padding_type,
            **conv_kwargs,
        )
        self.conv2 = FFC_BN_ACT(
            dim,
            dim,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            padding_type=padding_type,
            **conv_kwargs,
        )
        if spatial_transform_kwargs is not None:
            self.conv1 = LearnableSpatialTransformWrapper(
                self.conv1, **spatial_transform_kwargs
            )
            self.conv2 = LearnableSpatialTransformWrapper(
                self.conv2, **spatial_transform_kwargs
            )
        self.inline = inline

    def forward(self, x):
        if self.inline:
            x_l, x_g = (
                x[:, : -self.conv1.ffc.global_in_num],
                x[:, -self.conv1.ffc.global_in_num :],
            )
        else:
            x_l, x_g = x if isinstance(x, tuple) else (x, 0)

        id_l, id_g = x_l, x_g

        x_l, x_g = self.conv1((x_l, x_g))
        x_l, x_g = self.conv2((x_l, x_g))

        # residual connection on both branches
        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        if self.inline:
            out = torch.cat(out, dim=1)
        return out


class ConcatTupleLayer(nn.Module):
    def forward(self, x):
        assert isinstance(x, tuple)
        x_l, x_g = x
        assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
        if not torch.is_tensor(x_g):
            return x_l
        return torch.cat(x, dim=1)


class FFCResNetGenerator(nn.Module):
    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        n_downsampling=3,
        n_blocks=18,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
        activation_layer=nn.ReLU,
        up_norm_layer=nn.BatchNorm2d,
        up_activation=nn.ReLU(True),
        init_conv_kwargs={},
        downsample_conv_kwargs={},
        resnet_conv_kwargs={},
        spatial_transform_layers=None,
        spatial_transform_kwargs={},
        max_features=1024,
        out_ffc=False,
        out_ffc_kwargs={},
    ):
        assert n_blocks >= 0
        super().__init__()
        # The *_kwargs arguments above are ignored and replaced with the
        # configuration of the released LaMa checkpoint: ratio_gin/ratio_gout
        # of 0 for the stem and downsampling convs, 0.75 for the resnet
        # blocks, and LFU disabled everywhere.
        init_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
        downsample_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
        resnet_conv_kwargs = {
            "ratio_gin": 0.75,
            "ratio_gout": 0.75,
            "enable_lfu": False,
        }
        spatial_transform_kwargs = {}
        out_ffc_kwargs = {}

        model = [
            nn.ReflectionPad2d(3),
            FFC_BN_ACT(
                input_nc,
                ngf,
                kernel_size=7,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
                **init_conv_kwargs,
            ),
        ]

        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            if i == n_downsampling - 1:
                # The last downsampling layer feeds the resnet blocks, so its
                # output channel split must match their ratio_gin.
                cur_conv_kwargs = dict(downsample_conv_kwargs)
                cur_conv_kwargs["ratio_gout"] = resnet_conv_kwargs.get("ratio_gin", 0)
            else:
                cur_conv_kwargs = downsample_conv_kwargs
            model += [
                FFC_BN_ACT(
                    min(max_features, ngf * mult),
                    min(max_features, ngf * mult * 2),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    **cur_conv_kwargs,
                )
            ]

        mult = 2**n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        ### resnet blocks
        for i in range(n_blocks):
            cur_resblock = FFCResnetBlock(
                feats_num_bottleneck,
                padding_type=padding_type,
                activation_layer=activation_layer,
                norm_layer=norm_layer,
                **resnet_conv_kwargs,
            )
            if spatial_transform_layers is not None and i in spatial_transform_layers:
                cur_resblock = LearnableSpatialTransformWrapper(
                    cur_resblock, **spatial_transform_kwargs
                )
            model += [cur_resblock]

        model += [ConcatTupleLayer()]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    min(max_features, ngf * mult),
                    min(max_features, int(ngf * mult / 2)),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                up_norm_layer(min(max_features, int(ngf * mult / 2))),
                up_activation,
            ]

        if out_ffc:
            model += [
                FFCResnetBlock(
                    ngf,
                    padding_type=padding_type,
                    activation_layer=activation_layer,
                    norm_layer=norm_layer,
                    inline=True,
                    **out_ffc_kwargs,
                )
            ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
        ]
        model.append(nn.Sigmoid())
        self.model = nn.Sequential(*model)

    def forward(self, image, mask):
        return self.model(torch.cat([image, mask], dim=1))


class LaMa(nn.Module):
    def __init__(self, state_dict) -> None:
        super().__init__()
        self.model_arch = "LaMa"
        self.sub_type = "Inpaint"
        self.in_nc = 4
        self.out_nc = 3
        self.scale = 1

        self.min_size = None
        self.pad_mod = 8
        self.pad_to_square = False

        self.model = FFCResNetGenerator(self.in_nc, self.out_nc)
        # Remap checkpoint keys ("generator.model.*") onto this wrapper's
        # submodule ("model.model.*").
        self.state = {
            k.replace("generator.model", "model.model"): v
            for k, v in state_dict.items()
        }

        self.supports_fp16 = False
        self.support_bf16 = True

        self.load_state_dict(self.state, strict=False)

    def forward(self, img, mask):
        # Zero out the masked region, inpaint it, and composite the model
        # output back over the untouched unmasked pixels.
        masked_img = img * (1 - mask)
        inpainted = self.model(masked_img, mask)
        result = mask * inpainted + (1 - mask) * img
        return result
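

# A minimal smoke-test sketch (not part of the original module): it builds the
# wrapper with random weights and checks shapes plus the forward() composite.
# The 256x256 size and the empty state_dict are illustrative assumptions only;
# normally a real LaMa checkpoint would be passed to LaMa(state_dict).
if __name__ == "__main__":
    net = LaMa(state_dict={})  # strict=False tolerates the empty dict
    net.eval()
    img = torch.rand(1, 3, 256, 256)
    mask = (torch.rand(1, 1, 256, 256) > 0.5).float()  # binary hole mask
    with torch.no_grad():
        out = net(img, mask)
    assert out.shape == (1, 3, 256, 256)
    # Unmasked pixels are passed through unchanged by the composite in forward().
    assert torch.allclose(out * (1 - mask), img * (1 - mask))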