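"""Virtual Memory System: offload model weights back to their .safetensors file.

Instead of keeping every parameter resident in RAM/VRAM, `move_to_virtual_memory`
records, for each parameter, the key under which its tensor is stored in the
original .safetensors checkpoint and then deletes the parameter from the module.
`load_from_virtual_memory` later re-reads the tensors from disk and re-attaches
them, so large models can be swapped out between uses at the cost of extra disk
reads.
"""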
import torch
import time
import gc

from safetensors import safe_open
from comfy import model_management
from comfy.diffusers_convert import textenc_conversion_lst


ALWAYS_USE_VM = None

if isinstance(ALWAYS_USE_VM, bool):
    print(f'[Virtual Memory System] Forced = {ALWAYS_USE_VM}')

if 'cpu' in model_management.unet_offload_device().type.lower():
    logic_memory = model_management.total_ram
    global_virtual_memory_activated = ALWAYS_USE_VM if isinstance(ALWAYS_USE_VM, bool) else logic_memory < 30000
    print(f'[Virtual Memory System] Logic target is CPU, memory = {logic_memory}')
else:
    logic_memory = model_management.total_vram
    global_virtual_memory_activated = ALWAYS_USE_VM if isinstance(ALWAYS_USE_VM, bool) else logic_memory < 22000
    print(f'[Virtual Memory System] Logic target is GPU, memory = {logic_memory}')

print(f'[Virtual Memory System] Activated = {global_virtual_memory_activated}')
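
# Note: activation is heuristic. With CPU offload the system turns on below
# 30000 (the totals appear to be reported in MB, as in ComfyUI's
# model_management, so roughly 30 GB of RAM); with a GPU logic device, below
# 22000 (~22 GB of VRAM). Setting ALWAYS_USE_VM to True or False above bypasses
# the heuristic entirely.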


@torch.no_grad()
def recursive_set(obj, key, value):
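    """Assign `value` at a dotted attribute path, e.g. recursive_set(m, 'a.b.c', v).

    Walks nested attributes with getattr; if an intermediate object is missing,
    the call silently returns instead of raising.
    """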
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_set(getattr(obj, k1, None), k2, value)
    else:
        setattr(obj, key, value)


@torch.no_grad()
def recursive_del(obj, key):
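    """Delete the attribute at a dotted path, mirroring `recursive_set`."""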
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_del(getattr(obj, k1, None), k2)
    else:
        delattr(obj, key)


@torch.no_grad()
def force_load_state_dict(model, state_dict):
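    """Attach every tensor in `state_dict` to `model` as a frozen Parameter.

    Unlike `model.load_state_dict`, this writes attributes directly via
    `recursive_set` and deletes each entry as it goes, so the incoming dict is
    consumed and no second copy of the weights stays alive.
    """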
    for k in list(state_dict.keys()):
        p = torch.nn.Parameter(state_dict[k], requires_grad=False)
        recursive_set(model, k, p)
        del state_dict[k]
    return


@torch.no_grad()
def only_load_safetensors_keys(filename):
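    """Return the tensor keys stored in a .safetensors file, or None on failure.

    Only the file header is read, not the tensor data, so this stays cheap even
    for large checkpoints. A None result is how callers detect non-safetensors
    models.
    """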
    try:
        with safe_open(filename, framework="pt", device='cpu') as f:
            result = list(f.keys())
        assert len(result) > 0
        return result
    except Exception:
        return None


@torch.no_grad()
def move_to_virtual_memory(model, comfy_unload=True):
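    """Release `model`'s weights and remember how to re-read them from disk.

    Expects the caller to have attached `model.model_file = {'filename': ...,
    'prefix': ...}` beforehand. If `comfy_unload` is set, the currently loaded
    comfy model is unloaded first. For every parameter whose (prefix-mapped) key
    exists in the safetensors file, the parameter is deleted from the module and
    an entry (key in file, device, optional slice) is stored in
    `model.virtual_memory_dict` so `load_from_virtual_memory` can restore it.
    """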
    timer = time.time()

    if comfy_unload:
        model_management.unload_model()

    virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
    if isinstance(virtual_memory_dict, dict):
        # Already in virtual memory.
        return

    model_file = getattr(model, 'model_file', None)
    assert isinstance(model_file, dict)

    filename = model_file['filename']
    prefix = model_file['prefix']

    if hasattr(model, 'virtual_memory_safetensors_keys'):
        safetensors_keys = model.virtual_memory_safetensors_keys
    else:
        safetensors_keys = only_load_safetensors_keys(filename)
        model.virtual_memory_safetensors_keys = safetensors_keys

    if safetensors_keys is None:
        print('[Virtual Memory System] Error: The Virtual Memory System currently only supports safetensors models!')
        return

    sd = model.state_dict()
    original_device = list(sd.values())[0].device.type
    model_file['original_device'] = original_device

    virtual_memory_dict = {}
    for k, v in sd.items():
        current_key = k
        current_flag = None
        if prefix == 'refiner_clip':
            # Translate the in-memory (transformers-style) key back to the name
            # used inside the original SGM-style checkpoint before the lookup.
            current_key_in_safetensors = k

            for a, b in textenc_conversion_lst:
                current_key_in_safetensors = current_key_in_safetensors.replace(b, a)

            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.transformer.text_model.encoder.layers.', 'conditioner.embedders.0.model.transformer.resblocks.')
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.text_projection', 'conditioner.embedders.0.model.text_projection')
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.logit_scale', 'conditioner.embedders.0.model.logit_scale')
            current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.', 'conditioner.embedders.0.model.')

            # The checkpoint packs the separate q/k/v projections into a single
            # in_proj tensor; current_flag records which 1280-wide slice belongs
            # to this projection (1280 being the CLIP-G hidden width).
            for e in ["weight", "bias"]:
                for i, proj in enumerate(['q', 'k', 'v']):
                    e_flag = f'.{proj}_proj.{e}'
                    if current_key_in_safetensors.endswith(e_flag):
                        current_key_in_safetensors = current_key_in_safetensors[:-len(e_flag)] + f'.in_proj_{e}'
                        current_flag = (1280 * i, 1280 * (i + 1))
        else:
            current_key_in_safetensors = prefix + '.' + k

        current_device = torch.device(index=v.device.index, type=v.device.type)
        if current_key_in_safetensors in safetensors_keys:
            virtual_memory_dict[current_key] = (current_key_in_safetensors, current_device, current_flag)
            recursive_del(model, current_key)
        else:
            # print(f'[Virtual Memory System] Missed key: {current_key}')
            pass

    del sd
    gc.collect()
    model_management.soft_empty_cache()

    model.virtual_memory_dict = virtual_memory_dict

    print(f'[Virtual Memory System] time = {str("%.5f" % (time.time() - timer))}s: {prefix} released from {original_device}: {filename}')
    return


@torch.no_grad()
def load_from_virtual_memory(model):
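    """Re-read weights recorded by `move_to_virtual_memory` and re-attach them.

    Each tensor is loaded from the original safetensors file onto the device it
    originally lived on, sliced if a q/k/v flag was recorded, wrapped as a frozen
    Parameter and written back onto the module. The bookkeeping dict is then
    dropped so the model counts as fully resident again.
    """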
    timer = time.time()

    virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
    if not isinstance(virtual_memory_dict, dict):
        # Not in virtual memory.
        return

    model_file = getattr(model, 'model_file', None)
    assert isinstance(model_file, dict)

    filename = model_file['filename']
    prefix = model_file['prefix']
    original_device = model_file['original_device']

    with safe_open(filename, framework="pt", device=original_device) as f:
        for current_key, (current_key_in_safetensors, current_device, current_flag) in virtual_memory_dict.items():
            tensor = f.get_tensor(current_key_in_safetensors).to(current_device)
            if isinstance(current_flag, tuple) and len(current_flag) == 2:
                a, b = current_flag
                tensor = tensor[a:b]
            parameter = torch.nn.Parameter(tensor, requires_grad=False)
            recursive_set(model, current_key, parameter)

    print(f'[Virtual Memory System] time = {str("%.5f" % (time.time() - timer))}s: {prefix} loaded to {original_device}: {filename}')
    del model.virtual_memory_dict
    return


@torch.no_grad()
def try_move_to_virtual_memory(model, comfy_unload=True):
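    """Offload `model` only when it looks worthwhile.

    Does nothing when the global heuristic decided against virtual memory, or
    when `pipeline.xl_refiner` is not configured; in the latter case only one
    model is in play, so (presumably) there is nothing to swap against and the
    extra disk round-trips would only cost time.
    """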
    if not global_virtual_memory_activated:
        return

    import modules.default_pipeline as pipeline

    if pipeline.xl_refiner is None:
        # If users do not use refiner, no need to use this.
        return

    move_to_virtual_memory(model, comfy_unload)
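

# Illustrative round trip (editor's sketch; `some_clip_model` and the file path
# are placeholders, not part of this module):
#
#     some_clip_model.model_file = {'filename': '/path/to/model.safetensors',
#                                   'prefix': 'refiner_clip'}
#     try_move_to_virtual_memory(some_clip_model)   # weights dropped, keys recorded
#     ...                                           # do other work
#     load_from_virtual_memory(some_clip_model)     # weights re-read from disk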