move expansion to managed device (#364)

* move expansion to managed device

* move expansion to managed device

* move expansion to managed device

* move expansion to managed device

* move expansion to managed device

* move expansion to managed device
This commit is contained in:
lllyasviel 2023-09-13 12:48:27 -07:00 committed by GitHub
parent 53beede21d
commit e32f04da34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 17 additions and 3 deletions

View File

@ -1 +1 @@
version = '2.0.4'
version = '2.0.5'

View File

@ -89,7 +89,8 @@ def download_models():
def clear_comfy_args():
argv = sys.argv
sys.argv = [sys.argv[0]]
import comfy.cli_args
from comfy.cli_args import args as comfy_args
comfy_args.disable_cuda_malloc = True
sys.argv = argv

View File

@ -1,6 +1,9 @@
import torch
import comfy.model_management as model_management
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
from modules.path import fooocus_expansion_path
from comfy.sd import ModelPatcher
fooocus_magic_split = [
@ -27,14 +30,22 @@ class FooocusExpansion:
def __init__(self):
self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
self.pipe = pipeline('text-generation',
model=self.model,
tokenizer=self.tokenizer,
device='cpu',
torch_dtype=torch.float32)
print('Fooocus Expansion engine loaded.')
print(f'Fooocus Expansion engine loaded.')
def __call__(self, prompt, seed):
model_management.load_model_gpu(self.patcher)
self.pipe.device = self.patcher.load_device
seed = int(seed)
set_seed(seed)
origin = safe_str(prompt)

View File

@ -76,5 +76,7 @@ def text_encoder_device_patched():
def patch_all():
comfy.model_management.text_encoder_device = text_encoder_device_patched
print(f'Fooocus Text Processing Pipelines are retargeted to {str(comfy.model_management.text_encoder_device())}')
comfy.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched