bfloat16 vae (#456)

* bfloat16 vae

* bfloat16 vae

* bfloat16 vae
This commit is contained in:
lllyasviel 2023-09-20 08:16:20 -07:00 committed by GitHub
parent 6597b3df64
commit cdf642437c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 5 deletions

View File

@@ -1 +1 @@
version = '2.0.75'
version = '2.0.76'

View File

@@ -1,7 +1,5 @@
import threading
import numpy as np
import torch
buffer = []
outputs = []
@@ -10,6 +8,8 @@ outputs = []
def worker():
global buffer, outputs
import numpy as np
import torch
import time
import shared
import random

View File

@@ -1,3 +1,8 @@
from modules.patch import patch_all
patch_all()
import os
import random
import einops
@@ -13,10 +18,8 @@ from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, c
from comfy.model_base import SDXLRefiner
from comfy.sd import model_lora_keys_unet, model_lora_keys_clip, load_lora
from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner
from modules.patch import patch_all
patch_all()
opEmptyLatentImage = EmptyLatentImage()
opVAEDecode = VAEDecode()
opVAEEncode = VAEEncode()

View File

@@ -11,6 +11,7 @@ import comfy.k_diffusion.sampling
import comfy.sd1_clip
import modules.inpaint_worker as inpaint_worker
import comfy.ldm.modules.diffusionmodules.openaimodel
import comfy.ldm.modules.diffusionmodules.model
import comfy.sd
from comfy.k_diffusion import utils
@@ -391,7 +392,45 @@ def patched_SD1ClipModel_forward(self, tokens):
return z.float(), pooled_output.float()
# Cached VAE compute dtype; resolved lazily on first call to vae_dtype_patched().
VAE_DTYPE = None


def vae_dtype_patched():
    """Return the dtype to use for VAE computation (cached in VAE_DTYPE).

    Defaults to float32. On an NVIDIA device with PyTorch >= 2 and hardware
    bf16 support, upgrades to bfloat16 and prints a one-time notice.
    """
    global VAE_DTYPE
    if VAE_DTYPE is None:
        VAE_DTYPE = torch.float32
        if comfy.model_management.is_nvidia():
            torch_version = torch.version.__version__
            # Parse the whole major component rather than only the first
            # character, so a future "10.x" release is not read as major 1.
            if int(torch_version.split('.')[0]) >= 2:
                if torch.cuda.is_bf16_supported():
                    VAE_DTYPE = torch.bfloat16
                    print('BFloat16 VAE: Enabled')
    return VAE_DTYPE
def vae_bf16_upsample_forward(self, x):
    """2x nearest-neighbor upsample that tolerates bf16 inputs.

    Replacement forward for the VAE decoder's Upsample module: if the backend
    has no nearest-interpolate kernel for the input dtype (e.g. bfloat16),
    fall back to upsampling in float32 one channel-slice at a time to bound
    the peak memory of the temporary float32 copies.
    """
    try:
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
    except Exception:  # operation not implemented for bf16
        b, c, h, w = x.shape
        out = torch.empty((b, c, h * 2, w * 2), dtype=x.dtype, layout=x.layout, device=x.device)
        split = 8
        # max(1, ...) keeps the slice width valid when there are fewer than
        # `split` channels (a width of 0 would make range() raise ValueError).
        step = max(1, out.shape[1] // split)
        for i in range(0, out.shape[1], step):
            out[:, i:i + step] = torch.nn.functional.interpolate(
                x[:, i:i + step].to(torch.float32), scale_factor=2.0, mode="nearest"
            ).to(x.dtype)
        del x  # release the source tensor before keeping the upsampled copy
        x = out
    if self.with_conv:
        x = self.conv(x)
    return x
def patch_all():
    """Install this module's runtime monkey-patches into the comfy library.

    Must run before comfy-dependent code executes so later calls pick up the
    patched behavior (callers import this module and invoke patch_all()
    immediately).
    """
    # Report bfloat16 (when supported) as the VAE compute dtype.
    comfy.model_management.vae_dtype = vae_dtype_patched
    # Make VAE nearest-neighbor upsampling safe for bf16 inputs.
    comfy.ldm.modules.diffusionmodules.model.Upsample.forward = vae_bf16_upsample_forward
    # CLIP forward variant that returns float32 outputs.
    comfy.sd1_clip.SD1ClipModel.forward = patched_SD1ClipModel_forward
    # Replace LoRA weight calculation with the patched variant defined
    # elsewhere in this module.
    comfy.sd.ModelPatcher.calculate_weight = calculate_weight_patched