Fooocus/modules/virtual_memory.py
import torch
import time
import gc

from safetensors import safe_open
from comfy import model_management
from comfy.diffusers_convert import textenc_conversion_lst
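
# Set ALWAYS_USE_VM to True or False to force the virtual memory system on or off.
# When it stays None, activation falls back to the memory thresholds checked below
# (roughly 30 GB of system RAM when models are offloaded to CPU, 22 GB of VRAM otherwise;
# model_management reports both values in MB).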
ALWAYS_USE_VM = None

if isinstance(ALWAYS_USE_VM, bool):
    print(f'[Virtual Memory System] Forced = {ALWAYS_USE_VM}')

if 'cpu' in model_management.unet_offload_device().type.lower():
    logic_memory = model_management.total_ram
    global_virtual_memory_activated = ALWAYS_USE_VM if isinstance(ALWAYS_USE_VM, bool) else logic_memory < 30000
    print(f'[Virtual Memory System] Logic target is CPU, memory = {logic_memory}')
else:
    logic_memory = model_management.total_vram
    global_virtual_memory_activated = ALWAYS_USE_VM if isinstance(ALWAYS_USE_VM, bool) else logic_memory < 22000
    print(f'[Virtual Memory System] Logic target is GPU, memory = {logic_memory}')

print(f'[Virtual Memory System] Activated = {global_virtual_memory_activated}')
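
# Walk a dotted key like 'encoder.layers.0.weight' down the module tree and set the
# final attribute; missing intermediate attributes are silently skipped.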
@torch.no_grad()
def recursive_set(obj, key, value):
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_set(getattr(obj, k1, None), k2, value)
    else:
        setattr(obj, key, value)
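
# Same traversal as recursive_set, but removes the final attribute so the underlying
# tensor can be garbage-collected.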
@torch.no_grad()
def recursive_del(obj, key):
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_del(getattr(obj, k1, None), k2)
    else:
        delattr(obj, key)
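
# Write each state_dict tensor directly into the module as a frozen (requires_grad=False)
# Parameter, dropping entries from the dict as it goes so only one copy stays referenced.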
@torch.no_grad()
def force_load_state_dict(model, state_dict):
    for k in list(state_dict.keys()):
        p = torch.nn.Parameter(state_dict[k], requires_grad=False)
        recursive_set(model, k, p)
        del state_dict[k]
    return
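
# Read only the key names from a safetensors header (no tensor data is loaded).
# Returns None when the file is not a readable safetensors checkpoint.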
@torch.no_grad()
def only_load_safetensors_keys(filename):
    try:
        with safe_open(filename, framework="pt", device='cpu') as f:
            result = list(f.keys())
        assert len(result) > 0
        return result
    except:
        return None
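
# Release a model's parameters from RAM/VRAM: record, for every parameter, which
# safetensors key holds its data (plus the device it lived on and an optional slice),
# then delete the parameter from the module. load_from_virtual_memory() later restores
# the parameters by streaming the tensors back from disk.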
@torch.no_grad()
def move_to_virtual_memory(model, comfy_unload=True):
    timer = time.time()

    if comfy_unload:
        model_management.unload_model()

    virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
    if isinstance(virtual_memory_dict, dict):
        # Already in virtual memory.
        return

    model_file = getattr(model, 'model_file', None)
    assert isinstance(model_file, dict)

    filename = model_file['filename']
    prefix = model_file['prefix']

    if hasattr(model, 'virtual_memory_safetensors_keys'):
        safetensors_keys = model.virtual_memory_safetensors_keys
    else:
        safetensors_keys = only_load_safetensors_keys(filename)
        model.virtual_memory_safetensors_keys = safetensors_keys

    if safetensors_keys is None:
        print(f'[Virtual Memory System] Error: The Virtual Memory System currently only supports safetensors models!')
        return

    sd = model.state_dict()
    original_device = list(sd.values())[0].device.type
    model_file['original_device'] = original_device

    virtual_memory_dict = {}

    for k, v in sd.items():
        current_key = k
        current_flag = None
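
        # The refiner CLIP state_dict uses converted key names; map them back to the
        # naming used inside the SDXL checkpoint (conditioner.embedders.0.model...) so
        # they can be looked up in the safetensors file. Separate q/k/v projections
        # correspond to slices of one fused in_proj tensor, recorded via current_flag
        # as (start, end) row offsets in steps of 1280.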
        if prefix == 'refiner_clip':
            current_key_in_safetensors = k
            for a, b in textenc_conversion_lst:
                current_key_in_safetensors = current_key_in_safetensors.replace(b, a)
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.transformer.text_model.encoder.layers.', 'conditioner.embedders.0.model.transformer.resblocks.')
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.text_projection', 'conditioner.embedders.0.model.text_projection')
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.logit_scale', 'conditioner.embedders.0.model.logit_scale')
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.', 'conditioner.embedders.0.model.')
            for e in ["weight", "bias"]:
                for i, k in enumerate(['q', 'k', 'v']):
                    e_flag = f'.{k}_proj.{e}'
                    if current_key_in_safetensors.endswith(e_flag):
                        current_key_in_safetensors = current_key_in_safetensors[:-len(e_flag)] + f'.in_proj_{e}'
                        current_flag = (1280 * i, 1280 * (i + 1))
        else:
            current_key_in_safetensors = prefix + '.' + k
        current_device = torch.device(index=v.device.index, type=v.device.type)

        if current_key_in_safetensors in safetensors_keys:
            virtual_memory_dict[current_key] = (current_key_in_safetensors, current_device, current_flag)
            recursive_del(model, current_key)
        else:
            # print(f'[Virtual Memory System] Missed key: {current_key}')
            pass

    del sd
    gc.collect()
    model_management.soft_empty_cache()
    model.virtual_memory_dict = virtual_memory_dict

    print(f'[Virtual Memory System] time = {str("%.5f" % (time.time() - timer))}s: {prefix} released from {original_device}: {filename}')
    return
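
# Reverse of move_to_virtual_memory(): stream each recorded tensor back from the
# safetensors file onto its original device, slicing q/k/v rows out of fused in_proj
# tensors where a (start, end) flag was recorded.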
@torch.no_grad()
def load_from_virtual_memory(model):
    timer = time.time()

    virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
    if not isinstance(virtual_memory_dict, dict):
        # Not in virtual memory.
        return

    model_file = getattr(model, 'model_file', None)
    assert isinstance(model_file, dict)

    filename = model_file['filename']
    prefix = model_file['prefix']
    original_device = model_file['original_device']

    with safe_open(filename, framework="pt", device=original_device) as f:
        for current_key, (current_key_in_safetensors, current_device, current_flag) in virtual_memory_dict.items():
            tensor = f.get_tensor(current_key_in_safetensors).to(current_device)
            if isinstance(current_flag, tuple) and len(current_flag) == 2:
                a, b = current_flag
                tensor = tensor[a:b]
            parameter = torch.nn.Parameter(tensor, requires_grad=False)
            recursive_set(model, current_key, parameter)

    print(f'[Virtual Memory System] time = {str("%.5f" % (time.time() - timer))}s: {prefix} loaded to {original_device}: {filename}')
    del model.virtual_memory_dict
    return
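
# Offload only when the virtual memory system is active and a refiner is in use;
# without a refiner there is no base/refiner model swapping, so offloading would
# only add disk I/O cost.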
@torch.no_grad()
def try_move_to_virtual_memory(model, comfy_unload=True):
    if not global_virtual_memory_activated:
        return

    import modules.default_pipeline as pipeline

    if pipeline.xl_refiner is None:
        # If users do not use refiner, no need to use this.
        return

    move_to_virtual_memory(model, comfy_unload)
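
# Illustrative usage (a sketch, not code from this repository): the expected call
# pattern is to tag a loaded model with its source file, offload it while another
# model is active, and reload it right before it is needed again. The attribute
# names below mirror what the functions above read; the path is a placeholder.
#
#     model.model_file = {'filename': '/path/to/checkpoint.safetensors', 'prefix': 'refiner_clip'}
#     try_move_to_virtual_memory(model)   # drop parameters, remember where they came from
#     ...
#     load_from_virtual_memory(model)     # stream them back before the model is used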