94 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			94 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py
 | 
						|
 | 
						|
import os
 | 
						|
import torch
 | 
						|
import safetensors.torch as sf
 | 
						|
import torch.nn as nn
 | 
						|
import fcbh.model_management
 | 
						|
 | 
						|
from fcbh.model_patcher import ModelPatcher
 | 
						|
from modules.path import vae_approx_path
 | 
						|
 | 
						|
 | 
						|
class Block(nn.Module):
    """Residual unit: three 3x3 convolutions with an identity skip, then ReLU.

    The convolutions keep both channel count (``size``) and spatial size
    (kernel 3, stride 1, padding 1), so the skip connection can add the input
    directly. Two LeakyReLU(0.1) activations sit between the convolutions.

    NOTE: ``join`` / ``long`` and the positions inside ``long`` (convs at
    indices 0, 2, 4) form the state_dict keys that ``Interposer`` checkpoints
    are loaded against; do not rename or reorder them.
    """

    def __init__(self, size):
        super().__init__()
        self.join = nn.ReLU()

        # conv, LeakyReLU, conv, LeakyReLU, conv -- built in the same order as
        # the layers appear so parameter initialisation order is unchanged.
        layers = []
        for idx in range(3):
            if idx > 0:
                layers.append(nn.LeakyReLU(0.1))
            layers.append(nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1))
        self.long = nn.Sequential(*layers)

    def forward(self, x):
        # Residual sum first, activation afterwards.
        return self.join(self.long(x) + x)
 | 
						|
 | 
						|
 | 
						|
class Interposer(nn.Module):
    """Convolutional network mapping a 4-channel tensor to a 4-channel tensor.

    Layout:
      * head: two parallel branches on the input -- a single 3x3 conv
        (``head_short``) and a conv/LeakyReLU/conv/LeakyReLU/conv chain
        (``head_long``) -- summed and passed through ReLU (``head_join``);
      * core: three residual ``Block``s at ``hid`` (128) channels;
      * tail: ReLU followed by a 3x3 conv back to ``chan`` (4) channels.

    Every conv uses kernel 3, stride 1, padding 1, so spatial size is kept.

    NOTE: attribute names and Sequential positions are the state_dict keys
    the safetensors checkpoint is loaded against; keep them stable.
    """

    def __init__(self):
        super().__init__()
        self.chan = 4
        self.hid = 128

        def conv3x3(c_in, c_out):
            # Shape-preserving 3x3 convolution used throughout the network.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

        # Modules are created in the original order so parameter init order
        # (and thus seeded initial weights) is unchanged.
        self.head_join = nn.ReLU()
        self.head_short = conv3x3(self.chan, self.hid)
        self.head_long = nn.Sequential(
            conv3x3(self.chan, self.hid),
            nn.LeakyReLU(0.1),
            conv3x3(self.hid, self.hid),
            nn.LeakyReLU(0.1),
            conv3x3(self.hid, self.hid),
        )
        self.core = nn.Sequential(*(Block(self.hid) for _ in range(3)))
        self.tail = nn.Sequential(nn.ReLU(), conv3x3(self.hid, self.chan))

    def forward(self, x):
        merged = self.head_long(x) + self.head_short(x)
        features = self.core(self.head_join(merged))
        return self.tail(features)
 | 
						|
 | 
						|
 | 
						|
# Interposer checkpoint, located under ``vae_approx_path`` (from modules.path).
vae_approx_filename = os.path.join(vae_approx_path, 'xl-to-v1_interposer-v3.1.safetensors')

# Lazily-built ModelPatcher wrapping the Interposer; None until parse() first
# runs and populates it.
vae_approx_model = None
 | 
						|
 | 
						|
 | 
						|
def _load_vae_approx_model():
    """Build the Interposer, load its checkpoint and wrap it in a ModelPatcher.

    Weights come from ``vae_approx_filename``. They are cast to float16 when
    ``fcbh.model_management.should_use_fp16()`` returns true. The chosen dtype
    is recorded on the returned patcher as ``.dtype`` so callers can cast
    inputs to match.
    """
    model = Interposer()
    model.eval()
    model.load_state_dict(sf.load_file(vae_approx_filename))

    fp16 = fcbh.model_management.should_use_fp16()
    if fp16:
        model = model.half()

    patcher = ModelPatcher(
        model=model,
        load_device=fcbh.model_management.get_torch_device(),
        offload_device=torch.device('cpu')
    )
    patcher.dtype = torch.float16 if fp16 else torch.float32
    return patcher


def parse(x):
    """Run ``x`` through the cached Interposer and return the result.

    On the first call, the model is built via ``_load_vae_approx_model`` and
    stored in the module-level ``vae_approx_model``; later calls reuse it.
    Each call asks ``fcbh.model_management.load_model_gpu`` to place the
    model on its load device before inference.

    Args:
        x: Tensor fed to the Interposer's first ``nn.Conv2d`` layers, so it
            must have 4 channels -- presumably an (N, 4, H, W) latent batch;
            confirm against callers.

    Returns:
        The Interposer output, cast back to ``x``'s dtype and device. It is
        computed under ``torch.no_grad()`` and carries no autograd history.
    """
    global vae_approx_model

    if vae_approx_model is None:
        vae_approx_model = _load_vae_approx_model()

    fcbh.model_management.load_model_gpu(vae_approx_model)

    # Pure inference: without no_grad every call would record an autograd
    # graph through all Interposer weights and hold the activations alive.
    # ``x`` is never modified in place (``.to`` returns a new tensor or x
    # itself, and the model has no in-place ops), so no defensive clone is
    # needed.
    with torch.no_grad():
        h = x.to(device=vae_approx_model.load_device, dtype=vae_approx_model.dtype)
        return vae_approx_model.model(h).to(x)
 |