try fix colab with virtual ram (#378)

lllyasviel 2023-09-15 01:24:07 -07:00 committed by GitHub
parent b5b4fd27f1
commit cf7cde08b1
7 changed files with 248 additions and 28 deletions

View File

@@ -1 +1 @@
-version = '2.0.14'
+version = '2.0.16'

View File

@@ -16,6 +16,7 @@ def worker():
     import modules.default_pipeline as pipeline
     import modules.path
     import modules.patch
+    import modules.virtual_memory as virtual_memory

     from modules.sdxl_styles import apply_style, aspect_ratios, fooocus_expansion
     from modules.private_logger import log
@@ -80,10 +81,10 @@ def worker():
         progressbar(3, 'Loading models ...')

-        pipeline.refresh_base_model(base_model_name)
-        pipeline.refresh_refiner_model(refiner_model_name)
-        pipeline.refresh_loras(loras)
-        pipeline.clear_all_caches()
+        pipeline.refresh_everything(
+            refiner_model_name=refiner_model_name,
+            base_model_name=base_model_name,
+            loras=loras)

         progressbar(3, 'Processing prompts ...')
@@ -137,6 +138,8 @@ def worker():
                                                pool_top_k=negative_top_k)

         if pipeline.xl_refiner is not None:
+            virtual_memory.load_from_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model)
+
             for i, t in enumerate(tasks):
                 progressbar(11, f'Encoding refiner positive #{i + 1} ...')
                 t['c'][1] = pipeline.clip_encode(sd=pipeline.xl_refiner, texts=t['positive'],
@@ -147,6 +150,8 @@ def worker():
                 t['uc'][1] = pipeline.clip_encode(sd=pipeline.xl_refiner, texts=t['negative'],
                                                   pool_top_k=negative_top_k)

+            virtual_memory.try_move_to_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model)
+
         if performance_selction == 'Speed':
             steps = 30
             switch = 20

View File

@@ -10,6 +10,7 @@ import comfy.utils
 from comfy.sd import load_checkpoint_guess_config
 from nodes import VAEDecode, EmptyLatentImage
 from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
+from comfy.model_base import SDXLRefiner

 from modules.samplers_advanced import KSampler, KSamplerWithRefiner
 from modules.patch import patch_all
@@ -20,7 +21,15 @@ opVAEDecode = VAEDecode()

 class StableDiffusionModel:
-    def __init__(self, unet, vae, clip, clip_vision):
+    def __init__(self, unet, vae, clip, clip_vision, model_filename=None):
+        if isinstance(model_filename, str):
+            is_refiner = isinstance(unet.model, SDXLRefiner)
+            if unet is not None:
+                unet.model.model_file = dict(filename=model_filename, prefix='model')
+            if clip is not None:
+                clip.cond_stage_model.model_file = dict(filename=model_filename, prefix='refiner_clip' if is_refiner else 'base_clip')
+            if vae is not None:
+                vae.first_stage_model.model_file = dict(filename=model_filename, prefix='first_stage_model')
         self.unet = unet
         self.vae = vae
         self.clip = clip
@@ -38,7 +47,7 @@ class StableDiffusionModel:

 @torch.no_grad()
 def load_model(ckpt_filename):
     unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename)
-    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
+    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, model_filename=ckpt_filename)


 @torch.no_grad()

View File

@@ -2,6 +2,7 @@ import modules.core as core
 import os
 import torch
 import modules.path
+import modules.virtual_memory as virtual_memory
 import comfy.model_management as model_management

 from comfy.model_base import SDXL, SDXLRefiner
@@ -21,10 +22,12 @@ xl_base_patched_hash = ''

 def refresh_base_model(name):
     global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash

-    if xl_base_hash == str(name):
-        return
-
-    filename = os.path.join(modules.path.modelfile_path, name)
+    filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
+    model_hash = filename
+
+    if xl_base_hash == model_hash:
+        return

     if xl_base is not None:
         xl_base.to_meta()
@@ -36,21 +39,25 @@ def refresh_base_model(name):
         xl_base = None
         xl_base_hash = ''
         refresh_base_model(modules.path.default_base_model_name)
-        xl_base_hash = name
+        xl_base_hash = model_hash
         xl_base_patched = xl_base
         xl_base_patched_hash = ''
         return

-    xl_base_hash = name
+    xl_base_hash = model_hash
     xl_base_patched = xl_base
     xl_base_patched_hash = ''
-    print(f'Base model loaded: {xl_base_hash}')
+    print(f'Base model loaded: {model_hash}')
     return


 def refresh_refiner_model(name):
     global xl_refiner, xl_refiner_hash

-    if xl_refiner_hash == str(name):
+    filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
+    model_hash = filename
+
+    if xl_refiner_hash == model_hash:
         return

     if name == 'None':
@@ -59,8 +66,6 @@ def refresh_refiner_model(name):
         print(f'Refiner unloaded.')
         return

-    filename = os.path.join(modules.path.modelfile_path, name)
-
     if xl_refiner is not None:
         xl_refiner.to_meta()
         xl_refiner = None
@@ -73,8 +78,8 @@ def refresh_refiner_model(name):
         print(f'Refiner unloaded.')
         return

-    xl_refiner_hash = name
-    print(f'Refiner model loaded: {xl_refiner_hash}')
+    xl_refiner_hash = model_hash
+    print(f'Refiner model loaded: {model_hash}')

     xl_refiner.vae.first_stage_model.to('meta')
     xl_refiner.vae = None
@@ -100,13 +105,6 @@ def refresh_loras(loras):
     return


-refresh_base_model(modules.path.default_base_model_name)
-refresh_refiner_model(modules.path.default_refiner_model_name)
-refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)])
-
-expansion = FooocusExpansion()
-
-
 @torch.no_grad()
 def clip_encode_single(clip, text, verbose=False):
     cached = clip.fcs_cond_cache.get(text, None)
@@ -133,8 +131,6 @@ def clip_encode(sd, texts, pool_top_k=1):
     if len(texts) == 0:
         return None

-    model_management.soft_empty_cache()
-
     clip = sd.clip
     cond_list = []
     pooled_acc = 0
@@ -164,6 +160,29 @@ def clear_all_caches():
     clear_sd_cond_cache(xl_refiner)


+def refresh_everything(refiner_model_name, base_model_name, loras):
+    refresh_refiner_model(refiner_model_name)
+    if xl_refiner is not None:
+        virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
+        virtual_memory.try_move_to_virtual_memory(xl_refiner.clip.cond_stage_model)
+
+    refresh_base_model(base_model_name)
+    virtual_memory.load_from_virtual_memory(xl_base.unet.model)
+
+    refresh_loras(loras)
+    clear_all_caches()
+    return
+
+
+refresh_everything(
+    refiner_model_name=modules.path.default_refiner_model_name,
+    base_model_name=modules.path.default_base_model_name,
+    loras=[(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)]
+)
+
+expansion = FooocusExpansion()
+
+
 @torch.no_grad()
 def patch_all_models():
     assert xl_base is not None
@@ -181,7 +200,10 @@ def patch_all_models():

 @torch.no_grad()
 def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback):
     patch_all_models()
-    model_management.soft_empty_cache()
+
+    if xl_refiner is not None:
+        virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
+    virtual_memory.load_from_virtual_memory(xl_base.unet.model)

     empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)

View File

@@ -1,6 +1,7 @@
 from comfy.samplers import *
 import comfy.model_management
+import modules.virtual_memory


 class KSamplerWithRefiner:
@@ -152,6 +153,8 @@ class KSamplerWithRefiner:
                                              noise.shape[3], noise.shape[2], self.device, "negative")

         def refiner_switch():
+            modules.virtual_memory.try_move_to_virtual_memory(self.model_denoise.inner_model)
+            modules.virtual_memory.load_from_virtual_memory(self.refiner_model_denoise.inner_model)
             comfy.model_management.load_model_gpu(self.refiner_model_patcher)
             self.model_denoise.inner_model = self.refiner_model_denoise.inner_model

             for i in range(len(positive)):

modules/virtual_memory.py (new file, 175 lines)
View File

@@ -0,0 +1,175 @@
import torch
import gc

from safetensors import safe_open
from comfy import model_management
from comfy.diffusers_convert import textenc_conversion_lst


ALWAYS_USE_VM = None

if ALWAYS_USE_VM is not None:
    print(f'[Virtual Memory System] Forced = {ALWAYS_USE_VM}')

if 'cpu' in model_management.unet_offload_device().type.lower():
    logic_memory = model_management.total_ram
    global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 30000
    print(f'[Virtual Memory System] Logic target is CPU, memory = {logic_memory}')
else:
    logic_memory = model_management.total_vram
    global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 22000
    print(f'[Virtual Memory System] Logic target is GPU, memory = {logic_memory}')

print(f'[Virtual Memory System] Activated = {global_virtual_memory_activated}')


@torch.no_grad()
def recursive_set(obj, key, value):
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_set(getattr(obj, k1, None), k2, value)
    else:
        setattr(obj, key, value)


@torch.no_grad()
def recursive_del(obj, key):
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_del(getattr(obj, k1, None), k2)
    else:
        delattr(obj, key)


@torch.no_grad()
def force_load_state_dict(model, state_dict):
    for k in list(state_dict.keys()):
        p = torch.nn.Parameter(state_dict[k], requires_grad=False)
        recursive_set(model, k, p)
        del state_dict[k]
    return


@torch.no_grad()
def only_load_safetensors_keys(filename):
    try:
        with safe_open(filename, framework="pt", device='cpu') as f:
            result = list(f.keys())
        assert len(result) > 0
        return result
    except:
        return None


@torch.no_grad()
def move_to_virtual_memory(model, comfy_unload=True):
    if comfy_unload:
        model_management.unload_model()

    virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
    if isinstance(virtual_memory_dict, dict):
        # Already in virtual memory.
        return

    model_file = getattr(model, 'model_file', None)
    assert isinstance(model_file, dict)

    filename = model_file['filename']
    prefix = model_file['prefix']

    safetensors_keys = only_load_safetensors_keys(filename)

    if safetensors_keys is None:
        print(f'[Virtual Memory System] Error: The Virtual Memory System currently only support safetensors models!')
        return

    sd = model.state_dict()
    original_device = list(sd.values())[0].device.type
    model_file['original_device'] = original_device

    virtual_memory_dict = {}

    for k, v in sd.items():
        current_key = k
        current_flag = None
        if prefix == 'refiner_clip':
            current_key_in_safetensors = k

            for a, b in textenc_conversion_lst:
                current_key_in_safetensors = current_key_in_safetensors.replace(b, a)

            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.transformer.text_model.encoder.layers.', 'conditioner.embedders.0.model.transformer.resblocks.')
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.text_projection', 'conditioner.embedders.0.model.text_projection')
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.logit_scale', 'conditioner.embedders.0.model.logit_scale')
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.', 'conditioner.embedders.0.model.')

            for e in ["weight", "bias"]:
                for i, k in enumerate(['q', 'k', 'v']):
                    e_flag = f'.{k}_proj.{e}'
                    if current_key_in_safetensors.endswith(e_flag):
                        current_key_in_safetensors = current_key_in_safetensors[:-len(e_flag)] + f'.in_proj_{e}'
                        current_flag = (1280 * i, 1280 * (i + 1))
        else:
            current_key_in_safetensors = prefix + '.' + k

        current_device = torch.device(index=v.device.index, type=v.device.type)

        if current_key_in_safetensors in safetensors_keys:
            virtual_memory_dict[current_key] = (current_key_in_safetensors, current_device, current_flag)
            recursive_del(model, current_key)
        else:
            # print(f'[Virtual Memory System] Missed key: {current_key}')
            pass

    del sd

    gc.collect()
    model_management.soft_empty_cache()

    model.virtual_memory_dict = virtual_memory_dict
    print(f'[Virtual Memory System] {prefix} released from {original_device}: {filename}')
    return


@torch.no_grad()
def load_from_virtual_memory(model):
    virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
    if not isinstance(virtual_memory_dict, dict):
        # Not in virtual memory.
        return

    model_file = getattr(model, 'model_file', None)
    assert isinstance(model_file, dict)

    filename = model_file['filename']
    prefix = model_file['prefix']
    original_device = model_file['original_device']

    with safe_open(filename, framework="pt", device=original_device) as f:
        for current_key, (current_key_in_safetensors, current_device, current_flag) in virtual_memory_dict.items():
            tensor = f.get_tensor(current_key_in_safetensors).to(current_device)
            if isinstance(current_flag, tuple) and len(current_flag) == 2:
                a, b = current_flag
                tensor = tensor[a:b]
            parameter = torch.nn.Parameter(tensor, requires_grad=False)
            recursive_set(model, current_key, parameter)

    print(f'[Virtual Memory System] {prefix} loaded to {original_device}: {filename}')

    del model.virtual_memory_dict
    return


@torch.no_grad()
def try_move_to_virtual_memory(model, comfy_unload=True):
    if not global_virtual_memory_activated:
        return

    import modules.default_pipeline as pipeline

    if pipeline.xl_refiner is None:
        # If users do not use refiner, no need to use this.
        return

    move_to_virtual_memory(model, comfy_unload)
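For orientation, here is a minimal usage sketch of the new module's API; it is not part of the commit. A module is first tagged with a model_file dict pointing at its safetensors checkpoint (this is what core.StableDiffusionModel.__init__ now does), then released to virtual memory and later restored. The function name demo_swap, the unet_model argument, and the checkpoint path are illustrative assumptions.

import torch
import modules.virtual_memory as virtual_memory


def demo_swap(unet_model: torch.nn.Module, checkpoint_path: str):
    # Hypothetical example: tag the module so move_to_virtual_memory() knows
    # which safetensors file and key prefix its weights come from, mirroring
    # how core.StableDiffusionModel tags unet.model with prefix='model'.
    unet_model.model_file = dict(filename=checkpoint_path, prefix='model')

    # Release the weights: matching parameters are deleted from RAM/VRAM and
    # only (safetensors key, device, optional slice) bookkeeping is kept in
    # unet_model.virtual_memory_dict. Note try_move_to_virtual_memory() is a
    # no-op unless global_virtual_memory_activated is True and a refiner is loaded.
    virtual_memory.try_move_to_virtual_memory(unet_model)

    # ... run other models while these weights exist only on disk ...

    # Stream the tensors back from the safetensors file and reattach them as
    # parameters on the device they were originally on.
    virtual_memory.load_from_virtual_memory(unet_model)

This is the same pattern the updated default_pipeline.refresh_everything() and process_diffusion() follow for the base and refiner UNets.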

View File

@@ -1,3 +1,9 @@
+### 2.0.16
+
+* Virtual memory system implemented. Colab can now run both the base model and the refiner model with 7.8GB RAM + 5.3GB VRAM, and it never crashes.
+* If you are lucky enough to read this line, keep in mind that ComfyUI cannot do this. That is reasonable: Fooocus can be more aggressively optimized because it only needs to handle one fixed pipeline, whereas ComfyUI has to support arbitrary pipelines.
+* But considering only the optimization of this fixed workload, after 2.0.16 Fooocus has become the most optimized SDXL app, outperforming ComfyUI.
+
 ### 2.0.0

 * V2 released.
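A note on why the memory figures above are plausible: because the Fooocus pipeline is fixed, at most one SDXL UNet has to be materialized at a time. At the base-to-refiner switch, the base weights are released to the virtual memory system before the refiner weights are streamed back in, as KSamplerWithRefiner.refiner_switch() does in the diff above. A rough sketch of that ordering, with illustrative names rather than the commit's exact code:

import comfy.model_management
import modules.virtual_memory


def switch_base_to_refiner(base_unet_model, refiner_unet_model, refiner_patcher):
    # Illustrative sketch; mirrors the order used in refiner_switch() above.
    # 1. Release the base UNet: its parameters are dropped and only their
    #    safetensors keys and target devices are remembered.
    modules.virtual_memory.try_move_to_virtual_memory(base_unet_model)
    # 2. Re-materialize the refiner UNet by reading its tensors back from
    #    the checkpoint file onto their original device.
    modules.virtual_memory.load_from_virtual_memory(refiner_unet_model)
    # 3. Let comfy place the now-loaded refiner on the GPU for sampling.
    comfy.model_management.load_model_gpu(refiner_patcher)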