try fix colab with virtual ram (#378)
parent b5b4fd27f1
commit cf7cde08b1
@@ -1 +1 @@
-version = '2.0.14'
+version = '2.0.16'
@@ -16,6 +16,7 @@ def worker():
     import modules.default_pipeline as pipeline
     import modules.path
     import modules.patch
+    import modules.virtual_memory as virtual_memory
 
     from modules.sdxl_styles import apply_style, aspect_ratios, fooocus_expansion
     from modules.private_logger import log
@@ -80,10 +81,10 @@ def worker():
 
         progressbar(3, 'Loading models ...')
 
-        pipeline.refresh_base_model(base_model_name)
-        pipeline.refresh_refiner_model(refiner_model_name)
-        pipeline.refresh_loras(loras)
-        pipeline.clear_all_caches()
+        pipeline.refresh_everything(
+            refiner_model_name=refiner_model_name,
+            base_model_name=base_model_name,
+            loras=loras)
 
         progressbar(3, 'Processing prompts ...')
 
@@ -137,6 +138,8 @@ def worker():
                                               pool_top_k=negative_top_k)
 
         if pipeline.xl_refiner is not None:
+            virtual_memory.load_from_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model)
+
             for i, t in enumerate(tasks):
                 progressbar(11, f'Encoding refiner positive #{i + 1} ...')
                 t['c'][1] = pipeline.clip_encode(sd=pipeline.xl_refiner, texts=t['positive'],
@@ -147,6 +150,8 @@ def worker():
                 t['uc'][1] = pipeline.clip_encode(sd=pipeline.xl_refiner, texts=t['negative'],
                                                   pool_top_k=negative_top_k)
 
+            virtual_memory.try_move_to_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model)
+
         if performance_selction == 'Speed':
             steps = 30
             switch = 20
@@ -10,6 +10,7 @@ import comfy.utils
 from comfy.sd import load_checkpoint_guess_config
 from nodes import VAEDecode, EmptyLatentImage
 from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
+from comfy.model_base import SDXLRefiner
 from modules.samplers_advanced import KSampler, KSamplerWithRefiner
 from modules.patch import patch_all
 
@@ -20,7 +21,15 @@ opVAEDecode = VAEDecode()
 
 
 class StableDiffusionModel:
-    def __init__(self, unet, vae, clip, clip_vision):
+    def __init__(self, unet, vae, clip, clip_vision, model_filename=None):
+        if isinstance(model_filename, str):
+            is_refiner = isinstance(unet.model, SDXLRefiner)
+            if unet is not None:
+                unet.model.model_file = dict(filename=model_filename, prefix='model')
+            if clip is not None:
+                clip.cond_stage_model.model_file = dict(filename=model_filename, prefix='refiner_clip' if is_refiner else 'base_clip')
+            if vae is not None:
+                vae.first_stage_model.model_file = dict(filename=model_filename, prefix='first_stage_model')
         self.unet = unet
         self.vae = vae
         self.clip = clip
@@ -38,7 +47,7 @@ class StableDiffusionModel:
 @torch.no_grad()
 def load_model(ckpt_filename):
     unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename)
-    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
+    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, model_filename=ckpt_filename)
 
 
 @torch.no_grad()
@@ -2,6 +2,7 @@ import modules.core as core
 import os
 import torch
 import modules.path
+import modules.virtual_memory as virtual_memory
 import comfy.model_management as model_management
 
 from comfy.model_base import SDXL, SDXLRefiner
@@ -21,10 +22,12 @@ xl_base_patched_hash = ''
 
 def refresh_base_model(name):
     global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash
-    if xl_base_hash == str(name):
-        return
 
-    filename = os.path.join(modules.path.modelfile_path, name)
+    filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
+    model_hash = filename
+
+    if xl_base_hash == model_hash:
+        return
 
     if xl_base is not None:
         xl_base.to_meta()
@@ -36,21 +39,25 @@ def refresh_base_model(name):
         xl_base = None
         xl_base_hash = ''
         refresh_base_model(modules.path.default_base_model_name)
-        xl_base_hash = name
+        xl_base_hash = model_hash
         xl_base_patched = xl_base
         xl_base_patched_hash = ''
         return
 
-    xl_base_hash = name
+    xl_base_hash = model_hash
     xl_base_patched = xl_base
     xl_base_patched_hash = ''
-    print(f'Base model loaded: {xl_base_hash}')
+    print(f'Base model loaded: {model_hash}')
     return
 
 
 def refresh_refiner_model(name):
     global xl_refiner, xl_refiner_hash
-    if xl_refiner_hash == str(name):
+
+    filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
+    model_hash = filename
+
+    if xl_refiner_hash == model_hash:
         return
 
     if name == 'None':
@@ -59,8 +66,6 @@ def refresh_refiner_model(name):
         print(f'Refiner unloaded.')
         return
 
-    filename = os.path.join(modules.path.modelfile_path, name)
-
     if xl_refiner is not None:
         xl_refiner.to_meta()
         xl_refiner = None
@@ -73,8 +78,8 @@ def refresh_refiner_model(name):
         print(f'Refiner unloaded.')
         return
 
-    xl_refiner_hash = name
-    print(f'Refiner model loaded: {xl_refiner_hash}')
+    xl_refiner_hash = model_hash
+    print(f'Refiner model loaded: {model_hash}')
 
     xl_refiner.vae.first_stage_model.to('meta')
     xl_refiner.vae = None
@@ -100,13 +105,6 @@ def refresh_loras(loras):
     return
 
 
-refresh_base_model(modules.path.default_base_model_name)
-refresh_refiner_model(modules.path.default_refiner_model_name)
-refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)])
-
-expansion = FooocusExpansion()
-
-
 @torch.no_grad()
 def clip_encode_single(clip, text, verbose=False):
     cached = clip.fcs_cond_cache.get(text, None)
@@ -133,8 +131,6 @@ def clip_encode(sd, texts, pool_top_k=1):
     if len(texts) == 0:
         return None
 
-    model_management.soft_empty_cache()
-
     clip = sd.clip
     cond_list = []
     pooled_acc = 0
@@ -164,6 +160,29 @@ def clear_all_caches():
     clear_sd_cond_cache(xl_refiner)
 
 
+def refresh_everything(refiner_model_name, base_model_name, loras):
+    refresh_refiner_model(refiner_model_name)
+    if xl_refiner is not None:
+        virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
+        virtual_memory.try_move_to_virtual_memory(xl_refiner.clip.cond_stage_model)
+
+    refresh_base_model(base_model_name)
+    virtual_memory.load_from_virtual_memory(xl_base.unet.model)
+
+    refresh_loras(loras)
+    clear_all_caches()
+    return
+
+
+refresh_everything(
+    refiner_model_name=modules.path.default_refiner_model_name,
+    base_model_name=modules.path.default_base_model_name,
+    loras=[(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)]
+)
+
+expansion = FooocusExpansion()
+
+
 @torch.no_grad()
 def patch_all_models():
     assert xl_base is not None
@@ -181,7 +200,10 @@ def patch_all_models():
 @torch.no_grad()
 def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback):
     patch_all_models()
-    model_management.soft_empty_cache()
+
+    if xl_refiner is not None:
+        virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
+    virtual_memory.load_from_virtual_memory(xl_base.unet.model)
 
     empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
 
@@ -1,6 +1,7 @@
 from comfy.samplers import *
 
 import comfy.model_management
+import modules.virtual_memory
 
 
 class KSamplerWithRefiner:
@@ -152,6 +153,8 @@ class KSamplerWithRefiner:
                 noise.shape[3], noise.shape[2], self.device, "negative")
 
         def refiner_switch():
+            modules.virtual_memory.try_move_to_virtual_memory(self.model_denoise.inner_model)
+            modules.virtual_memory.load_from_virtual_memory(self.refiner_model_denoise.inner_model)
             comfy.model_management.load_model_gpu(self.refiner_model_patcher)
             self.model_denoise.inner_model = self.refiner_model_denoise.inner_model
             for i in range(len(positive)):
modules/virtual_memory.py (new file, 175 lines)
@@ -0,0 +1,175 @@
+import torch
+import gc
+
+from safetensors import safe_open
+from comfy import model_management
+from comfy.diffusers_convert import textenc_conversion_lst
+
+
+ALWAYS_USE_VM = None
+
+if ALWAYS_USE_VM is not None:
+    print(f'[Virtual Memory System] Forced = {ALWAYS_USE_VM}')
+
+if 'cpu' in model_management.unet_offload_device().type.lower():
+    logic_memory = model_management.total_ram
+    global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 30000
+    print(f'[Virtual Memory System] Logic target is CPU, memory = {logic_memory}')
+else:
+    logic_memory = model_management.total_vram
+    global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 22000
+    print(f'[Virtual Memory System] Logic target is GPU, memory = {logic_memory}')
+
+
+print(f'[Virtual Memory System] Activated = {global_virtual_memory_activated}')
+
+
+@torch.no_grad()
+def recursive_set(obj, key, value):
+    if obj is None:
+        return
+    if '.' in key:
+        k1, k2 = key.split('.', 1)
+        recursive_set(getattr(obj, k1, None), k2, value)
+    else:
+        setattr(obj, key, value)
+
+
+@torch.no_grad()
+def recursive_del(obj, key):
+    if obj is None:
+        return
+    if '.' in key:
+        k1, k2 = key.split('.', 1)
+        recursive_del(getattr(obj, k1, None), k2)
+    else:
+        delattr(obj, key)
+
+
+@torch.no_grad()
+def force_load_state_dict(model, state_dict):
+    for k in list(state_dict.keys()):
+        p = torch.nn.Parameter(state_dict[k], requires_grad=False)
+        recursive_set(model, k, p)
+        del state_dict[k]
+    return
+
+
+@torch.no_grad()
+def only_load_safetensors_keys(filename):
+    try:
+        with safe_open(filename, framework="pt", device='cpu') as f:
+            result = list(f.keys())
+        assert len(result) > 0
+        return result
+    except:
+        return None
+
+
+@torch.no_grad()
+def move_to_virtual_memory(model, comfy_unload=True):
+    if comfy_unload:
+        model_management.unload_model()
+
+    virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
+    if isinstance(virtual_memory_dict, dict):
+        # Already in virtual memory.
+        return
+
+    model_file = getattr(model, 'model_file', None)
+    assert isinstance(model_file, dict)
+
+    filename = model_file['filename']
+    prefix = model_file['prefix']
+
+    safetensors_keys = only_load_safetensors_keys(filename)
+
+    if safetensors_keys is None:
+        print(f'[Virtual Memory System] Error: The Virtual Memory System currently only support safetensors models!')
+        return
+
+    sd = model.state_dict()
+    original_device = list(sd.values())[0].device.type
+    model_file['original_device'] = original_device
+
+    virtual_memory_dict = {}
+
+    for k, v in sd.items():
+        current_key = k
+        current_flag = None
+        if prefix == 'refiner_clip':
+            current_key_in_safetensors = k
+
+            for a, b in textenc_conversion_lst:
+                current_key_in_safetensors = current_key_in_safetensors.replace(b, a)
+
+            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.transformer.text_model.encoder.layers.', 'conditioner.embedders.0.model.transformer.resblocks.')
+            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.text_projection', 'conditioner.embedders.0.model.text_projection')
+            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.logit_scale', 'conditioner.embedders.0.model.logit_scale')
+            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.', 'conditioner.embedders.0.model.')
+
+            for e in ["weight", "bias"]:
+                for i, k in enumerate(['q', 'k', 'v']):
+                    e_flag = f'.{k}_proj.{e}'
+                    if current_key_in_safetensors.endswith(e_flag):
+                        current_key_in_safetensors = current_key_in_safetensors[:-len(e_flag)] + f'.in_proj_{e}'
+                        current_flag = (1280 * i, 1280 * (i + 1))
+        else:
+            current_key_in_safetensors = prefix + '.' + k
+        current_device = torch.device(index=v.device.index, type=v.device.type)
+        if current_key_in_safetensors in safetensors_keys:
+            virtual_memory_dict[current_key] = (current_key_in_safetensors, current_device, current_flag)
+            recursive_del(model, current_key)
+        else:
+            # print(f'[Virtual Memory System] Missed key: {current_key}')
+            pass
+
+    del sd
+    gc.collect()
+    model_management.soft_empty_cache()
+
+    model.virtual_memory_dict = virtual_memory_dict
+    print(f'[Virtual Memory System] {prefix} released from {original_device}: {filename}')
+    return
+
+
+@torch.no_grad()
+def load_from_virtual_memory(model):
+    virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
+    if not isinstance(virtual_memory_dict, dict):
+        # Not in virtual memory.
+        return
+
+    model_file = getattr(model, 'model_file', None)
+    assert isinstance(model_file, dict)
+
+    filename = model_file['filename']
+    prefix = model_file['prefix']
+    original_device = model_file['original_device']
+
+    with safe_open(filename, framework="pt", device=original_device) as f:
+        for current_key, (current_key_in_safetensors, current_device, current_flag) in virtual_memory_dict.items():
+            tensor = f.get_tensor(current_key_in_safetensors).to(current_device)
+            if isinstance(current_flag, tuple) and len(current_flag) == 2:
+                a, b = current_flag
+                tensor = tensor[a:b]
+            parameter = torch.nn.Parameter(tensor, requires_grad=False)
+            recursive_set(model, current_key, parameter)
+
+    print(f'[Virtual Memory System] {prefix} loaded to {original_device}: {filename}')
+    del model.virtual_memory_dict
+    return
+
+
+@torch.no_grad()
+def try_move_to_virtual_memory(model, comfy_unload=True):
+    if not global_virtual_memory_activated:
+        return
+
+    import modules.default_pipeline as pipeline
+
+    if pipeline.xl_refiner is None:
+        # If users do not use refiner, no need to use this.
+        return
+
+    move_to_virtual_memory(model, comfy_unload)
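For orientation (an editor's sketch, not part of the commit): the new module is driven by the `model_file` dict that `core.StableDiffusionModel` now attaches to each sub-model, recording the source safetensors checkpoint and a key prefix so weights can be dropped from memory and re-read on demand. A minimal usage sketch, assuming a Fooocus checkout where these modules are importable and a hypothetical local checkpoint path:

```python
# Sketch only; the checkpoint path below is hypothetical.
import modules.core as core
import modules.virtual_memory as virtual_memory

refiner = core.load_model('/content/sd_xl_refiner_1.0.safetensors')  # attaches model_file dicts

# Drop the refiner CLIP weights from memory; only the safetensors key names,
# target devices, and slicing flags are kept in virtual_memory_dict.
virtual_memory.move_to_virtual_memory(refiner.clip.cond_stage_model)

# ... run the base model and encode prompts while the refiner weights stay on disk ...

# Re-read exactly those tensors from the checkpoint right before the refiner is needed.
virtual_memory.load_from_virtual_memory(refiner.clip.cond_stage_model)
```

In the commit itself the gated wrapper `try_move_to_virtual_memory` is used instead; it becomes a no-op when enough RAM/VRAM is detected or when no refiner is configured in `modules.default_pipeline`.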
@@ -1,3 +1,9 @@
+### 2.0.16
+
+* Virtual memory system implemented. Colab can now run both the base model and the refiner model within 7.8GB RAM + 5.3GB VRAM, and it never crashes.
+* If you are lucky enough to read this line, keep in mind that ComfyUI cannot do this. This is reasonable: Fooocus can be more optimized because it only needs to handle a fixed pipeline, while ComfyUI needs to consider arbitrary pipelines.
+* But if we only consider the optimization of this fixed workload, then after 2.0.16 Fooocus has become the most optimized SDXL app, outperforming ComfyUI.
+
 ### 2.0.0
 
 * V2 released.
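As a separate aside (not part of the diff), the RAM saving described in the changelog boils down to one trick: a module's parameters are deleted from RAM while only their safetensors key names are remembered, and the tensors are re-read from disk right before the module is used. A toy, self-contained illustration using only standard PyTorch and safetensors APIs (the file name is made up):

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

model = torch.nn.Linear(256, 256)
save_file({k: v.clone() for k, v in model.state_dict().items()}, 'toy_checkpoint.safetensors')

# "Move to virtual memory": forget the tensors, remember only their key names.
kept_keys = list(model.state_dict().keys())
for k in kept_keys:
    delattr(model, k)  # removes 'weight' and 'bias' from the module (and from RAM)
# The real module walks dotted keys recursively (recursive_del) for nested sub-modules.

# "Load from virtual memory": lazily re-read exactly those tensors from disk.
with safe_open('toy_checkpoint.safetensors', framework='pt', device='cpu') as f:
    for k in kept_keys:
        setattr(model, k, torch.nn.Parameter(f.get_tensor(k), requires_grad=False))

print(model(torch.randn(1, 256)).shape)  # the module works again: torch.Size([1, 256])
```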