try fix colab with virtual ram (#378)

Authored by lllyasviel on 2023-09-15 01:24:07 -07:00; committed by GitHub
parent b5b4fd27f1
commit cf7cde08b1
7 changed files with 248 additions and 28 deletions

fooocus_version.py

@@ -1 +1 @@
version = '2.0.14'
version = '2.0.16'

modules/async_worker.py

@@ -16,6 +16,7 @@ def worker():
import modules.default_pipeline as pipeline
import modules.path
import modules.patch
import modules.virtual_memory as virtual_memory
from modules.sdxl_styles import apply_style, aspect_ratios, fooocus_expansion
from modules.private_logger import log
@@ -80,10 +81,10 @@ def worker():
progressbar(3, 'Loading models ...')
pipeline.refresh_base_model(base_model_name)
pipeline.refresh_refiner_model(refiner_model_name)
pipeline.refresh_loras(loras)
pipeline.clear_all_caches()
pipeline.refresh_everything(
refiner_model_name=refiner_model_name,
base_model_name=base_model_name,
loras=loras)
progressbar(3, 'Processing prompts ...')
@@ -137,6 +138,8 @@ def worker():
pool_top_k=negative_top_k)
if pipeline.xl_refiner is not None:
virtual_memory.load_from_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model)
for i, t in enumerate(tasks):
progressbar(11, f'Encoding refiner positive #{i + 1} ...')
t['c'][1] = pipeline.clip_encode(sd=pipeline.xl_refiner, texts=t['positive'],
@@ -147,6 +150,8 @@ def worker():
t['uc'][1] = pipeline.clip_encode(sd=pipeline.xl_refiner, texts=t['negative'],
pool_top_k=negative_top_k)
virtual_memory.try_move_to_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model)
if performance_selction == 'Speed':
steps = 30
switch = 20
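
Below is a minimal sketch, not taken from this commit, of the bracket the two hunks above place around refiner prompt encoding: restore the refiner's CLIP weights from virtual memory, run the encoding loop, then release them again. load_from_virtual_memory and try_move_to_virtual_memory are the functions added in modules/virtual_memory.py further down; encode_one_task is a hypothetical stand-in for the pipeline.clip_encode(...) calls.

import modules.virtual_memory as virtual_memory

def encode_refiner_prompts(pipeline, tasks, encode_one_task):
    # Hypothetical helper mirroring the load/encode/release bracket above.
    if pipeline.xl_refiner is None:
        return
    clip_model = pipeline.xl_refiner.clip.cond_stage_model
    virtual_memory.load_from_virtual_memory(clip_model)    # re-read real weights from disk
    for t in tasks:
        encode_one_task(pipeline.xl_refiner, t)            # e.g. pipeline.clip_encode(...)
    virtual_memory.try_move_to_virtual_memory(clip_model)  # hand the weights back to virtual memory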

modules/core.py

@@ -10,6 +10,7 @@ import comfy.utils
from comfy.sd import load_checkpoint_guess_config
from nodes import VAEDecode, EmptyLatentImage
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
from comfy.model_base import SDXLRefiner
from modules.samplers_advanced import KSampler, KSamplerWithRefiner
from modules.patch import patch_all
@@ -20,7 +21,15 @@ opVAEDecode = VAEDecode()
class StableDiffusionModel:
def __init__(self, unet, vae, clip, clip_vision):
def __init__(self, unet, vae, clip, clip_vision, model_filename=None):
if isinstance(model_filename, str):
is_refiner = isinstance(unet.model, SDXLRefiner)
if unet is not None:
unet.model.model_file = dict(filename=model_filename, prefix='model')
if clip is not None:
clip.cond_stage_model.model_file = dict(filename=model_filename, prefix='refiner_clip' if is_refiner else 'base_clip')
if vae is not None:
vae.first_stage_model.model_file = dict(filename=model_filename, prefix='first_stage_model')
self.unet = unet
self.vae = vae
self.clip = clip
@@ -38,7 +47,7 @@ class StableDiffusionModel:
@torch.no_grad()
def load_model(ckpt_filename):
unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename)
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, model_filename=ckpt_filename)
@torch.no_grad()
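
As a reading aid, a small self-contained sketch of what the new model_filename bookkeeping above amounts to: each sub-model remembers which checkpoint file its weights came from and under which key prefix, so the virtual memory system can later re-read them from disk. The _Wrapper class and the checkpoint path below are placeholders for illustration only.

import torch

class _Wrapper:
    # Stand-in for the comfy wrappers, which expose the real torch module as `.model`.
    def __init__(self, inner):
        self.model = inner

def tag_for_virtual_memory(wrapper, checkpoint_path, prefix):
    # Same bookkeeping as `unet.model.model_file = dict(filename=..., prefix=...)` above.
    wrapper.model.model_file = dict(filename=checkpoint_path, prefix=prefix)

unet = _Wrapper(torch.nn.Linear(4, 4))
tag_for_virtual_memory(unet, 'sd_xl_base_1.0.safetensors', 'model')
print(unet.model.model_file)  # {'filename': 'sd_xl_base_1.0.safetensors', 'prefix': 'model'}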

modules/default_pipeline.py

@@ -2,6 +2,7 @@ import modules.core as core
import os
import torch
import modules.path
import modules.virtual_memory as virtual_memory
import comfy.model_management as model_management
from comfy.model_base import SDXL, SDXLRefiner
@@ -21,10 +22,12 @@ xl_base_patched_hash = ''
def refresh_base_model(name):
global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash
if xl_base_hash == str(name):
return
filename = os.path.join(modules.path.modelfile_path, name)
filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
model_hash = filename
if xl_base_hash == model_hash:
return
if xl_base is not None:
xl_base.to_meta()
@@ -36,21 +39,25 @@ def refresh_base_model(name):
xl_base = None
xl_base_hash = ''
refresh_base_model(modules.path.default_base_model_name)
xl_base_hash = name
xl_base_hash = model_hash
xl_base_patched = xl_base
xl_base_patched_hash = ''
return
xl_base_hash = name
xl_base_hash = model_hash
xl_base_patched = xl_base
xl_base_patched_hash = ''
print(f'Base model loaded: {xl_base_hash}')
print(f'Base model loaded: {model_hash}')
return
def refresh_refiner_model(name):
global xl_refiner, xl_refiner_hash
if xl_refiner_hash == str(name):
filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
model_hash = filename
if xl_refiner_hash == model_hash:
return
if name == 'None':
@@ -59,8 +66,6 @@ def refresh_refiner_model(name):
print(f'Refiner unloaded.')
return
filename = os.path.join(modules.path.modelfile_path, name)
if xl_refiner is not None:
xl_refiner.to_meta()
xl_refiner = None
@@ -73,8 +78,8 @@ def refresh_refiner_model(name):
print(f'Refiner unloaded.')
return
xl_refiner_hash = name
print(f'Refiner model loaded: {xl_refiner_hash}')
xl_refiner_hash = model_hash
print(f'Refiner model loaded: {model_hash}')
xl_refiner.vae.first_stage_model.to('meta')
xl_refiner.vae = None
@@ -100,13 +105,6 @@ def refresh_loras(loras):
return
refresh_base_model(modules.path.default_base_model_name)
refresh_refiner_model(modules.path.default_refiner_model_name)
refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)])
expansion = FooocusExpansion()
@torch.no_grad()
def clip_encode_single(clip, text, verbose=False):
cached = clip.fcs_cond_cache.get(text, None)
@@ -133,8 +131,6 @@ def clip_encode(sd, texts, pool_top_k=1):
if len(texts) == 0:
return None
model_management.soft_empty_cache()
clip = sd.clip
cond_list = []
pooled_acc = 0
@@ -164,6 +160,29 @@ def clear_all_caches():
clear_sd_cond_cache(xl_refiner)
def refresh_everything(refiner_model_name, base_model_name, loras):
refresh_refiner_model(refiner_model_name)
if xl_refiner is not None:
virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
virtual_memory.try_move_to_virtual_memory(xl_refiner.clip.cond_stage_model)
refresh_base_model(base_model_name)
virtual_memory.load_from_virtual_memory(xl_base.unet.model)
refresh_loras(loras)
clear_all_caches()
return
refresh_everything(
refiner_model_name=modules.path.default_refiner_model_name,
base_model_name=modules.path.default_base_model_name,
loras=[(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)]
)
expansion = FooocusExpansion()
@torch.no_grad()
def patch_all_models():
assert xl_base is not None
@@ -181,7 +200,10 @@ def patch_all_models():
@torch.no_grad()
def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback):
patch_all_models()
model_management.soft_empty_cache()
if xl_refiner is not None:
virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
virtual_memory.load_from_virtual_memory(xl_base.unet.model)
empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
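
The hunks above also park unused sub-models on PyTorch's meta device (xl_base.to_meta(), first_stage_model.to('meta')). A tiny standalone example of that device, independent of this commit: converting a module to meta keeps its parameter shapes but drops the real storage, so the memory can be reclaimed.

import torch

net = torch.nn.Linear(1024, 1024)
print(net.weight.device)  # cpu: real storage is allocated

net.to('meta')            # parameters become shape-only placeholders
print(net.weight.device)  # meta
print(net.weight.shape)   # torch.Size([1024, 1024]): shape metadata survives, the data does not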

modules/samplers_advanced.py

@@ -1,6 +1,7 @@
from comfy.samplers import *
import comfy.model_management
import modules.virtual_memory
class KSamplerWithRefiner:
@@ -152,6 +153,8 @@ class KSamplerWithRefiner:
noise.shape[3], noise.shape[2], self.device, "negative")
def refiner_switch():
modules.virtual_memory.try_move_to_virtual_memory(self.model_denoise.inner_model)
modules.virtual_memory.load_from_virtual_memory(self.refiner_model_denoise.inner_model)
comfy.model_management.load_model_gpu(self.refiner_model_patcher)
self.model_denoise.inner_model = self.refiner_model_denoise.inner_model
for i in range(len(positive)):
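
For orientation, a toy sketch (all names hypothetical, not the sampler's real control flow) of the role refiner_switch plays inside KSamplerWithRefiner: the base model denoises up to the switch step, then its weights are parked in virtual memory, the refiner's weights are restored and loaded onto the GPU, and the remaining steps run with the refiner.

def sample_with_refiner(total_steps, switch_step, base_step, refiner_step, refiner_switch):
    # Toy driver loop; the real sampler performs the switch inside its callback machinery.
    for step in range(total_steps):
        if step == switch_step:
            refiner_switch()   # park base weights, restore refiner weights (as in the hunk above)
        if step < switch_step:
            base_step(step)
        else:
            refiner_step(step)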

modules/virtual_memory.py (new file, 175 lines)

@@ -0,0 +1,175 @@
import torch
import gc
from safetensors import safe_open
from comfy import model_management
from comfy.diffusers_convert import textenc_conversion_lst
ALWAYS_USE_VM = None
if ALWAYS_USE_VM is not None:
print(f'[Virtual Memory System] Forced = {ALWAYS_USE_VM}')
if 'cpu' in model_management.unet_offload_device().type.lower():
logic_memory = model_management.total_ram
global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 30000
print(f'[Virtual Memory System] Logic target is CPU, memory = {logic_memory}')
else:
logic_memory = model_management.total_vram
global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 22000
print(f'[Virtual Memory System] Logic target is GPU, memory = {logic_memory}')
print(f'[Virtual Memory System] Activated = {global_virtual_memory_activated}')
@torch.no_grad()
def recursive_set(obj, key, value):
if obj is None:
return
if '.' in key:
k1, k2 = key.split('.', 1)
recursive_set(getattr(obj, k1, None), k2, value)
else:
setattr(obj, key, value)
@torch.no_grad()
def recursive_del(obj, key):
if obj is None:
return
if '.' in key:
k1, k2 = key.split('.', 1)
recursive_del(getattr(obj, k1, None), k2)
else:
delattr(obj, key)
@torch.no_grad()
def force_load_state_dict(model, state_dict):
for k in list(state_dict.keys()):
p = torch.nn.Parameter(state_dict[k], requires_grad=False)
recursive_set(model, k, p)
del state_dict[k]
return
@torch.no_grad()
def only_load_safetensors_keys(filename):
try:
with safe_open(filename, framework="pt", device='cpu') as f:
result = list(f.keys())
assert len(result) > 0
return result
except:
return None
@torch.no_grad()
def move_to_virtual_memory(model, comfy_unload=True):
if comfy_unload:
model_management.unload_model()
virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
if isinstance(virtual_memory_dict, dict):
# Already in virtual memory.
return
model_file = getattr(model, 'model_file', None)
assert isinstance(model_file, dict)
filename = model_file['filename']
prefix = model_file['prefix']
safetensors_keys = only_load_safetensors_keys(filename)
if safetensors_keys is None:
print(f'[Virtual Memory System] Error: The Virtual Memory System currently only support safetensors models!')
return
sd = model.state_dict()
original_device = list(sd.values())[0].device.type
model_file['original_device'] = original_device
virtual_memory_dict = {}
for k, v in sd.items():
current_key = k
current_flag = None
if prefix == 'refiner_clip':
current_key_in_safetensors = k
for a, b in textenc_conversion_lst:
current_key_in_safetensors = current_key_in_safetensors.replace(b, a)
current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.transformer.text_model.encoder.layers.', 'conditioner.embedders.0.model.transformer.resblocks.')
current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.text_projection', 'conditioner.embedders.0.model.text_projection')
current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.logit_scale', 'conditioner.embedders.0.model.logit_scale')
current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.', 'conditioner.embedders.0.model.')
for e in ["weight", "bias"]:
for i, k in enumerate(['q', 'k', 'v']):
e_flag = f'.{k}_proj.{e}'
if current_key_in_safetensors.endswith(e_flag):
current_key_in_safetensors = current_key_in_safetensors[:-len(e_flag)] + f'.in_proj_{e}'
current_flag = (1280 * i, 1280 * (i + 1))
else:
current_key_in_safetensors = prefix + '.' + k
current_device = torch.device(index=v.device.index, type=v.device.type)
if current_key_in_safetensors in safetensors_keys:
virtual_memory_dict[current_key] = (current_key_in_safetensors, current_device, current_flag)
recursive_del(model, current_key)
else:
# print(f'[Virtual Memory System] Missed key: {current_key}')
pass
del sd
gc.collect()
model_management.soft_empty_cache()
model.virtual_memory_dict = virtual_memory_dict
print(f'[Virtual Memory System] {prefix} released from {original_device}: {filename}')
return
@torch.no_grad()
def load_from_virtual_memory(model):
virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
if not isinstance(virtual_memory_dict, dict):
# Not in virtual memory.
return
model_file = getattr(model, 'model_file', None)
assert isinstance(model_file, dict)
filename = model_file['filename']
prefix = model_file['prefix']
original_device = model_file['original_device']
with safe_open(filename, framework="pt", device=original_device) as f:
for current_key, (current_key_in_safetensors, current_device, current_flag) in virtual_memory_dict.items():
tensor = f.get_tensor(current_key_in_safetensors).to(current_device)
if isinstance(current_flag, tuple) and len(current_flag) == 2:
a, b = current_flag
tensor = tensor[a:b]
parameter = torch.nn.Parameter(tensor, requires_grad=False)
recursive_set(model, current_key, parameter)
print(f'[Virtual Memory System] {prefix} loaded to {original_device}: {filename}')
del model.virtual_memory_dict
return
@torch.no_grad()
def try_move_to_virtual_memory(model, comfy_unload=True):
if not global_virtual_memory_activated:
return
import modules.default_pipeline as pipeline
if pipeline.xl_refiner is None:
# If users do not use refiner, no need to use this.
return
move_to_virtual_memory(model, comfy_unload)
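
The core trick of this module is that parameters are addressed by their dotted state_dict keys: recursive_del removes them from the module so their storage can be freed, and recursive_set re-attaches freshly loaded tensors without rebuilding the model. A standalone toy demonstration of that round trip, using a small Linear layer instead of the actual SDXL weights:

import torch

def recursive_set(obj, key, value):
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_set(getattr(obj, k1, None), k2, value)
    else:
        setattr(obj, key, value)

def recursive_del(obj, key):
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_del(getattr(obj, k1, None), k2)
    else:
        delattr(obj, key)

net = torch.nn.Sequential(torch.nn.Linear(8, 8))
saved = net.state_dict()['0.weight'].clone()

recursive_del(net, '0.weight')    # the weight parameter is removed from the module
recursive_set(net, '0.weight', torch.nn.Parameter(saved, requires_grad=False))
print(net.state_dict()['0.weight'].shape)  # torch.Size([8, 8]): restored in place

For the refiner CLIP, move_to_virtual_memory additionally remaps each key to the checkpoint's own naming via textenc_conversion_lst and records a (start, end) row slice, because the checkpoint stores a fused in_proj weight while the loaded model uses separate q/k/v projections of width 1280.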

update_log.md

@@ -1,3 +1,9 @@
### 2.0.16
* Virtual memory system implemented. Colab can now run both the base model and the refiner model with 7.8GB RAM + 5.3GB VRAM, and it never crashes.
* If you are lucky enough to read this line, keep in mind that ComfyUI cannot do this. It is reasonable that Fooocus is more optimized here, because it only needs to handle one fixed pipeline, whereas ComfyUI has to support arbitrary pipelines.
* But judged on this fixed workload alone, as of 2.0.16 Fooocus has become the most optimized SDXL app, outperforming ComfyUI.
### 2.0.0
* V2 released.