94 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			94 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py
 | |
| 
 | |
| import os
 | |
| import torch
 | |
| import safetensors.torch as sf
 | |
| import torch.nn as nn
 | |
| import fcbh.model_management
 | |
| 
 | |
| from fcbh.model_patcher import ModelPatcher
 | |
| from modules.path import vae_approx_path
 | |
| 
 | |
| 
 | |
| class Block(nn.Module):
 | |
|     def __init__(self, size):
 | |
|         super().__init__()
 | |
|         self.join = nn.ReLU()
 | |
|         self.long = nn.Sequential(
 | |
|             nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
 | |
|             nn.LeakyReLU(0.1),
 | |
|             nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
 | |
|             nn.LeakyReLU(0.1),
 | |
|             nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
 | |
|         )
 | |
| 
 | |
|     def forward(self, x):
 | |
|         y = self.long(x)
 | |
|         z = self.join(y + x)
 | |
|         return z
 | |
| 
 | |
| 
 | |
| class Interposer(nn.Module):
 | |
|     def __init__(self):
 | |
|         super().__init__()
 | |
|         self.chan = 4
 | |
|         self.hid = 128
 | |
| 
 | |
|         self.head_join = nn.ReLU()
 | |
|         self.head_short = nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1)
 | |
|         self.head_long = nn.Sequential(
 | |
|             nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1),
 | |
|             nn.LeakyReLU(0.1),
 | |
|             nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
 | |
|             nn.LeakyReLU(0.1),
 | |
|             nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
 | |
|         )
 | |
|         self.core = nn.Sequential(
 | |
|             Block(self.hid),
 | |
|             Block(self.hid),
 | |
|             Block(self.hid),
 | |
|         )
 | |
|         self.tail = nn.Sequential(
 | |
|             nn.ReLU(),
 | |
|             nn.Conv2d(self.hid, self.chan, kernel_size=3, stride=1, padding=1)
 | |
|         )
 | |
| 
 | |
|     def forward(self, x):
 | |
|         y = self.head_join(
 | |
|             self.head_long(x) +
 | |
|             self.head_short(x)
 | |
|         )
 | |
|         z = self.core(y)
 | |
|         return self.tail(z)
 | |
| 
 | |
| 
 | |
| vae_approx_model = None
 | |
| vae_approx_filename = os.path.join(vae_approx_path, 'xl-to-v1_interposer-v3.1.safetensors')
 | |
| 
 | |
| 
 | |
| def parse(x):
 | |
|     global vae_approx_model
 | |
| 
 | |
|     x_origin = x.clone()
 | |
| 
 | |
|     if vae_approx_model is None:
 | |
|         model = Interposer()
 | |
|         model.eval()
 | |
|         sd = sf.load_file(vae_approx_filename)
 | |
|         model.load_state_dict(sd)
 | |
|         fp16 = fcbh.model_management.should_use_fp16()
 | |
|         if fp16:
 | |
|             model = model.half()
 | |
|         vae_approx_model = ModelPatcher(
 | |
|             model=model,
 | |
|             load_device=fcbh.model_management.get_torch_device(),
 | |
|             offload_device=torch.device('cpu')
 | |
|         )
 | |
|         vae_approx_model.dtype = torch.float16 if fp16 else torch.float32
 | |
| 
 | |
|     fcbh.model_management.load_model_gpu(vae_approx_model)
 | |
| 
 | |
|     x = x_origin.to(device=vae_approx_model.load_device, dtype=vae_approx_model.dtype)
 | |
|     x = vae_approx_model.model(x).to(x_origin)
 | |
|     return x
 |