diff --git a/fooocus_version.py b/fooocus_version.py
index b311819..10afe32 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.0.75'
+version = '2.0.76'
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 5f315ad..379b2c8 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -1,7 +1,5 @@
 import threading
 
-import numpy as np
-import torch
 
 buffer = []
 outputs = []
@@ -10,6 +8,8 @@ outputs = []
 
 def worker():
     global buffer, outputs
+    import numpy as np
+    import torch
     import time
     import shared
     import random
diff --git a/modules/core.py b/modules/core.py
index e57a81a..82f8fad 100644
--- a/modules/core.py
+++ b/modules/core.py
@@ -1,3 +1,8 @@
+from modules.patch import patch_all
+
+patch_all()
+
+
 import os
 import random
 import einops
@@ -13,10 +18,8 @@ from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, c
 from comfy.model_base import SDXLRefiner
 from comfy.sd import model_lora_keys_unet, model_lora_keys_clip, load_lora
 from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner
-from modules.patch import patch_all
 
-patch_all()
 
 opEmptyLatentImage = EmptyLatentImage()
 opVAEDecode = VAEDecode()
 opVAEEncode = VAEEncode()
diff --git a/modules/patch.py b/modules/patch.py
index 34ce2da..7344b5a 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -11,6 +11,7 @@ import comfy.k_diffusion.sampling
 import comfy.sd1_clip
 import modules.inpaint_worker as inpaint_worker
 import comfy.ldm.modules.diffusionmodules.openaimodel
+import comfy.ldm.modules.diffusionmodules.model
 import comfy.sd
 
 from comfy.k_diffusion import utils
@@ -391,7 +392,45 @@ def patched_SD1ClipModel_forward(self, tokens):
     return z.float(), pooled_output.float()
 
 
+VAE_DTYPE = None
+
+
+def vae_dtype_patched():
+    global VAE_DTYPE
+    if VAE_DTYPE is None:
+        VAE_DTYPE = torch.float32
+        if comfy.model_management.is_nvidia():
+            torch_version = torch.version.__version__
+            if int(torch_version[0]) >= 2:
+                if torch.cuda.is_bf16_supported():
+                    VAE_DTYPE = torch.bfloat16
+                    print('BFloat16 VAE: Enabled')
+    return VAE_DTYPE
+
+
+def vae_bf16_upsample_forward(self, x):
+    try:
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+    except:  # operation not implemented for bf16
+        b, c, h, w = x.shape
+        out = torch.empty((b, c, h * 2, w * 2), dtype=x.dtype, layout=x.layout, device=x.device)
+        split = 8
+        l = out.shape[1] // split
+        for i in range(0, out.shape[1], l):
+            out[:, i:i + l] = torch.nn.functional.interpolate(x[:, i:i + l].to(torch.float32), scale_factor=2.0,
+                                                              mode="nearest").to(x.dtype)
+        del x
+        x = out
+
+    if self.with_conv:
+        x = self.conv(x)
+    return x
+
+
 def patch_all():
+    comfy.model_management.vae_dtype = vae_dtype_patched
+    comfy.ldm.modules.diffusionmodules.model.Upsample.forward = vae_bf16_upsample_forward
+
     comfy.sd1_clip.SD1ClipModel.forward = patched_SD1ClipModel_forward
     comfy.sd.ModelPatcher.calculate_weight = calculate_weight_patched
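
Note (commentary, not part of the diff): this change enables a bfloat16 VAE path. vae_dtype_patched() selects torch.bfloat16 for the VAE only on NVIDIA GPUs running torch >= 2 with bf16 support, and float32 otherwise. vae_bf16_upsample_forward() guards the decoder's nearest-neighbor upsample: on torch builds where "nearest" interpolation is not implemented for bf16, it upsamples the channels in eight chunks, casting each chunk to float32 and back so the float32 working memory stays bounded. A minimal standalone sketch of that fallback strategy follows; the function name upsample_nearest_2x and the demo shapes are made up for illustration, and only the chunked cast-and-interpolate logic mirrors the patch:

import torch
import torch.nn.functional as F


def upsample_nearest_2x(x: torch.Tensor) -> torch.Tensor:
    # Fast path: recent torch builds implement nearest-neighbor
    # interpolation for bf16 directly.
    try:
        return F.interpolate(x, scale_factor=2.0, mode="nearest")
    except Exception:  # "nearest" not implemented for bf16 on this build
        # Fallback: upsample roughly 1/8 of the channels at a time in
        # float32, casting back to the input dtype, so only a fraction
        # of the tensor is ever held in float32 at once.
        b, c, h, w = x.shape
        out = torch.empty((b, c, h * 2, w * 2), dtype=x.dtype, device=x.device)
        chunk = max(1, c // 8)  # guard against c < 8, unlike the raw patch
        for i in range(0, c, chunk):
            out[:, i:i + chunk] = F.interpolate(
                x[:, i:i + chunk].to(torch.float32),
                scale_factor=2.0, mode="nearest",
            ).to(x.dtype)
        return out


if __name__ == '__main__':
    x = torch.randn(1, 512, 64, 64, dtype=torch.bfloat16)
    print(upsample_nearest_2x(x).shape)  # torch.Size([1, 512, 128, 128])

The other two changes are plumbing for the same feature: moving the numpy/torch imports inside worker() in async_worker.py means they load when the worker thread runs rather than at module import time, and core.py now calls patch_all() before its other imports, so the monkey-patches above are installed before the patched comfy classes are first used.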