minor revise (#382)

* minor revise

* minor revise
lllyasviel 2023-09-15 03:40:06 -07:00 committed by GitHub
parent d1b4389098
commit 8ef00d87b4
3 changed files with 15 additions and 6 deletions

View File

@@ -1 +1 @@
-version = '2.0.17'
+version = '2.0.18'

View File

@@ -1,4 +1,7 @@
 import os
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 import sys
 import platform
 import fooocus_version
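For context, PYTORCH_ENABLE_MPS_FALLBACK=1 is PyTorch's switch that lets operators without an MPS (Apple Silicon) kernel fall back to the CPU instead of raising an error, and it is typically set before the first import of torch, which is why the new line sits directly under import os at the top of the entry script. A minimal standalone sketch of the same pattern (illustration only, not Fooocus code):

import os

# Must be set before the first `import torch`, otherwise the fallback is not picked up.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
    # Ops without an MPS implementation now run on CPU instead of failing.
    print(torch.randn(4, 4, device=device).sum().item())
else:
    print("MPS not available; nothing to fall back from.")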

View File

@ -1,4 +1,5 @@
import torch
import time
import gc
from safetensors import safe_open
@@ -8,16 +9,16 @@ from comfy.diffusers_convert import textenc_conversion_lst
 ALWAYS_USE_VM = None
-if ALWAYS_USE_VM is not None:
+if isinstance(ALWAYS_USE_VM, bool):
     print(f'[Virtual Memory System] Forced = {ALWAYS_USE_VM}')
 if 'cpu' in model_management.unet_offload_device().type.lower():
     logic_memory = model_management.total_ram
-    global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 30000
+    global_virtual_memory_activated = ALWAYS_USE_VM if isinstance(ALWAYS_USE_VM, bool) else logic_memory < 30000
     print(f'[Virtual Memory System] Logic target is CPU, memory = {logic_memory}')
 else:
     logic_memory = model_management.total_vram
-    global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 22000
+    global_virtual_memory_activated = ALWAYS_USE_VM if isinstance(ALWAYS_USE_VM, bool) else logic_memory < 22000
     print(f'[Virtual Memory System] Logic target is GPU, memory = {logic_memory}')
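Here ALWAYS_USE_VM acts as a manual override for the virtual memory system: left as None, activation is decided by a memory threshold (30000 against total RAM when the offload target is CPU, 22000 against total VRAM otherwise, values apparently in MB); set to an explicit True or False, that value wins. Switching the check from `is not None` to `isinstance(ALWAYS_USE_VM, bool)` means only a genuine boolean counts as an override. A self-contained sketch of that decision logic (the function name and arguments below are illustrative stand-ins for the model_management globals used above):

def should_activate_virtual_memory(always_use_vm, offload_device_type, total_ram_mb, total_vram_mb):
    if 'cpu' in offload_device_type:
        logic_memory, threshold = total_ram_mb, 30000
    else:
        logic_memory, threshold = total_vram_mb, 22000
    # Only an explicit bool forces the decision; anything else falls through to the memory check.
    if isinstance(always_use_vm, bool):
        return always_use_vm
    return logic_memory < threshold

# Examples:
assert should_activate_virtual_memory(None, 'cpu', 16000, 0) is True        # low RAM -> activate
assert should_activate_virtual_memory(False, 'cuda', 65536, 8192) is False  # explicit override wins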
@@ -68,6 +69,8 @@ def only_load_safetensors_keys(filename):
 @torch.no_grad()
 def move_to_virtual_memory(model, comfy_unload=True):
+    timer = time.time()
     if comfy_unload:
         model_management.unload_model()
@@ -129,12 +132,15 @@ def move_to_virtual_memory(model, comfy_unload=True):
     model_management.soft_empty_cache()
     model.virtual_memory_dict = virtual_memory_dict
-    print(f'[Virtual Memory System] {prefix} released from {original_device}: {filename}')
+    print(f'[Virtual Memory System] time = {str("%.5f" % (time.time() - timer))}s: {prefix} released from {original_device}: {filename}')
     return

 @torch.no_grad()
 def load_from_virtual_memory(model):
+    timer = time.time()
     virtual_memory_dict = getattr(model, 'virtual_memory_dict', None)
     if not isinstance(virtual_memory_dict, dict):
         # Not in virtual memory.
@@ -156,7 +162,7 @@ def load_from_virtual_memory(model):
         parameter = torch.nn.Parameter(tensor, requires_grad=False)
         recursive_set(model, current_key, parameter)
-    print(f'[Virtual Memory System] {prefix} loaded to {original_device}: {filename}')
+    print(f'[Virtual Memory System] time = {str("%.5f" % (time.time() - timer))}s: {prefix} loaded to {original_device}: {filename}')
     del model.virtual_memory_dict
     return
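Both transfer functions now follow the same instrumentation shape: record time.time() on entry, do the work, then log the elapsed seconds with five decimal places under the existing [Virtual Memory System] prefix. A rough standalone sketch of that pattern around a safetensors load (the file name, key list, and return value are illustrative; the real function re-attaches the tensors to the model via recursive_set):

import time
import torch
from safetensors import safe_open

@torch.no_grad()
def load_tensors_timed(filename, keys, device='cpu'):
    timer = time.time()
    loaded = {}
    with safe_open(filename, framework='pt', device=device) as f:
        for key in keys:
            # Wrap each tensor as a frozen parameter, mirroring the re-attachment above.
            loaded[key] = torch.nn.Parameter(f.get_tensor(key), requires_grad=False)
    print(f'[Virtual Memory System] time = {"%.5f" % (time.time() - timer)}s: loaded {len(loaded)} tensors from {filename}')
    return loaded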