100 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			100 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging as logger
 | |
| 
 | |
| from .architecture.DAT import DAT
 | |
| from .architecture.face.codeformer import CodeFormer
 | |
| from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
 | |
| from .architecture.face.restoreformer_arch import RestoreFormer
 | |
| from .architecture.HAT import HAT
 | |
| from .architecture.LaMa import LaMa
 | |
| from .architecture.OmniSR.OmniSR import OmniSR
 | |
| from .architecture.RRDB import RRDBNet as ESRGAN
 | |
| from .architecture.SCUNet import SCUNet
 | |
| from .architecture.SPSR import SPSRNet as SPSR
 | |
| from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
 | |
| from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
 | |
| from .architecture.Swin2SR import Swin2SR
 | |
| from .architecture.SwinIR import SwinIR
 | |
| from .types import PyTorchModel
 | |
| 
 | |
| 
 | |
class UnsupportedModel(Exception):
    """Raised by ``load_state_dict`` when no architecture accepts the state dict."""
 | |
| 
 | |
| 
 | |
def load_state_dict(state_dict) -> PyTorchModel:
    """Instantiate the model architecture that matches ``state_dict``.

    The architecture is chosen by looking for key names that are
    characteristic of each supported network. The checks run in a fixed
    order and the first match wins; when nothing matches, the state dict is
    handed to the ESRGAN loader as a last resort.

    Args:
        state_dict: Mapping of parameter names to values. The weights may be
            nested under a ``"params_ema"``, ``"params-ema"`` or ``"params"``
            key, in which case the nested mapping is used instead.

    Returns:
        The model object built by the matching architecture class.

    Raises:
        UnsupportedModel: If no signature matched and the ESRGAN fallback
            failed to build a model. The underlying error is attached as
            ``__cause__``.
    """
    logger.debug("Loading state dict into pytorch model arch")

    # Unwrap checkpoints that nest the weights under a wrapper key.
    # Order matters: the first wrapper key present is the one used.
    for wrapper_key in ("params_ema", "params-ema", "params"):
        if wrapper_key in state_dict.keys():
            state_dict = state_dict[wrapper_key]
            break

    # A set gives constant-time membership tests for the checks below.
    state_dict_keys = set(state_dict.keys())

    # SRVGGNet Real-ESRGAN (v2)
    if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
        model = RealESRGANv2(state_dict)
    # SPSR (ESRGAN with lots of extra layers)
    elif "f_HR_conv1.0.weight" in state_dict_keys:
        model = SPSR(state_dict)
    # Swift-SRGAN
    elif (
        "model" in state_dict_keys
        and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
    ):
        model = SwiftSRGAN(state_dict)
    # SwinIR, Swin2SR, HAT (all share the residual-group key; sub-keys decide)
    elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
        if (
            "layers.0.residual_group.blocks.0.conv_block.cab.0.weight"
            in state_dict_keys
        ):
            model = HAT(state_dict)
        elif "patch_embed.proj.weight" in state_dict_keys:
            model = Swin2SR(state_dict)
        else:
            model = SwinIR(state_dict)
    # GFPGAN
    elif (
        "toRGB.0.weight" in state_dict_keys
        and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
    ):
        model = GFPGANv1Clean(state_dict)
    # RestoreFormer
    elif (
        "encoder.conv_in.weight" in state_dict_keys
        and "encoder.down.0.block.0.norm1.weight" in state_dict_keys
    ):
        model = RestoreFormer(state_dict)
    # CodeFormer
    elif (
        "encoder.blocks.0.weight" in state_dict_keys
        and "quantize.embedding.weight" in state_dict_keys
    ):
        model = CodeFormer(state_dict)
    # LaMa
    elif (
        "model.model.1.bn_l.running_mean" in state_dict_keys
        or "generator.model.1.bn_l.running_mean" in state_dict_keys
    ):
        model = LaMa(state_dict)
    # Omni-SR
    elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
        model = OmniSR(state_dict)
    # SCUNet
    elif "m_head.0.weight" in state_dict_keys and "m_tail.0.weight" in state_dict_keys:
        model = SCUNet(state_dict)
    # DAT
    elif "layers.0.blocks.2.attn.attn_mask_0" in state_dict_keys:
        model = DAT(state_dict)
    # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
    else:
        try:
            model = ESRGAN(state_dict)
        except Exception as err:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit still propagate, and keep the root cause attached
            # for debugging.
            raise UnsupportedModel from err
    return model
 |