feat: add ability to load checkpoints and loras from multiple locations (#1256)

* Add ability to load checkpoints and loras from multiple locations

* Found another location a default path is required

* feat: use array as default

---------

Co-authored-by: Manuel Schmid <manuel.schmid@odt.net>
This commit is contained in:
dooglewoogle 2024-02-26 00:47:14 +13:00 committed by GitHub
parent 7cfb5e742d
commit ef1999c52c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 43 additions and 18 deletions

View File

@ -68,7 +68,6 @@ vae_approx_filenames = [
'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
]
def ini_args():
from args_manager import args
return args
@ -101,9 +100,9 @@ def download_models():
return
if not args.always_download_new_model:
if not os.path.exists(os.path.join(config.path_checkpoints, config.default_base_model_name)):
if not os.path.exists(os.path.join(config.paths_checkpoints[0], config.default_base_model_name)):
for alternative_model_name in config.previous_default_models:
if os.path.exists(os.path.join(config.path_checkpoints, alternative_model_name)):
if os.path.exists(os.path.join(config.paths_checkpoints[0], alternative_model_name)):
print(f'You do not have [{config.default_base_model_name}] but you have [{alternative_model_name}].')
print(f'Fooocus will use [{alternative_model_name}] to avoid downloading new models, '
f'but you are not using latest models.')
@ -113,11 +112,11 @@ def download_models():
break
for file_name, url in config.checkpoint_downloads.items():
load_file_from_url(url=url, model_dir=config.path_checkpoints, file_name=file_name)
load_file_from_url(url=url, model_dir=config.paths_checkpoints[0], file_name=file_name)
for file_name, url in config.embeddings_downloads.items():
load_file_from_url(url=url, model_dir=config.path_embeddings, file_name=file_name)
for file_name, url in config.lora_downloads.items():
load_file_from_url(url=url, model_dir=config.path_loras, file_name=file_name)
load_file_from_url(url=url, model_dir=config.paths_loras[0], file_name=file_name)
return

View File

@ -114,7 +114,7 @@ def get_path_output() -> str:
return path_output
def get_dir_or_set_default(key, default_value):
def get_dir_or_set_default(key, default_value, as_array=False):
global config_dict, visited_keys, always_save_keys
if key not in visited_keys:
@ -125,18 +125,29 @@ def get_dir_or_set_default(key, default_value):
v = config_dict.get(key, None)
if isinstance(v, str) and os.path.exists(v) and os.path.isdir(v):
return v if not as_array else [v]
elif isinstance(v, list) and all([os.path.exists(d) and os.path.isdir(d) for d in v]):
return v
else:
if v is not None:
print(f'Failed to load config key: {json.dumps({key:v})} is invalid or does not exist; will use {json.dumps({key:default_value})} instead.')
dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value))
os.makedirs(dp, exist_ok=True)
if isinstance(default_value, list):
dp = []
for path in default_value:
abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), path))
dp.append(abs_path)
os.makedirs(abs_path, exist_ok=True)
else:
dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value))
os.makedirs(dp, exist_ok=True)
if as_array:
dp = [dp]
config_dict[key] = dp
return dp
path_checkpoints = get_dir_or_set_default('path_checkpoints', '../models/checkpoints/')
path_loras = get_dir_or_set_default('path_loras', '../models/loras/')
paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/checkpoints/'], True)
paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
@ -404,14 +415,18 @@ model_filenames = []
lora_filenames = []
def get_model_filenames(folder_path, name_filter=None):
return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter)
def get_model_filenames(folder_paths, name_filter=None):
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
files = []
for folder in folder_paths:
files += get_files_from_folder(folder, extensions, name_filter)
return files
def update_all_model_names():
global model_filenames, lora_filenames
model_filenames = get_model_filenames(path_checkpoints)
lora_filenames = get_model_filenames(path_loras)
model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras)
return
@ -456,7 +471,7 @@ def downloading_inpaint_models(v):
def downloading_sdxl_lcm_lora():
load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors',
model_dir=path_loras,
model_dir=paths_loras[0],
file_name='sdxl_lcm_lora.safetensors'
)
return 'sdxl_lcm_lora.safetensors'

View File

@ -18,6 +18,7 @@ from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode,
from ldm_patched.contrib.external_freelunch import FreeU_V2
from ldm_patched.modules.sample import prepare_mask
from modules.lora import match_lora
from modules.util import get_file_from_folder_list
from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip
from modules.config import path_embeddings
from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete
@ -79,7 +80,7 @@ class StableDiffusionModel:
if os.path.exists(name):
lora_filename = name
else:
lora_filename = os.path.join(modules.config.path_loras, name)
lora_filename = get_file_from_folder_list(name, modules.config.paths_loras)
if not os.path.exists(lora_filename):
print(f'Lora file not found: {lora_filename}')

View File

@ -11,6 +11,7 @@ from extras.expansion import FooocusExpansion
from ldm_patched.modules.model_base import SDXL, SDXLRefiner
from modules.sample_hijack import clip_separate
from modules.util import get_file_from_folder_list
model_base = core.StableDiffusionModel()
@ -60,7 +61,7 @@ def assert_model_integrity():
def refresh_base_model(name):
global model_base
filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
if model_base.filename == filename:
return
@ -76,7 +77,7 @@ def refresh_base_model(name):
def refresh_refiner_model(name):
global model_refiner
filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
if model_refiner.filename == filename:
return

View File

@ -177,5 +177,14 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None):
return filenames
def get_file_from_folder_list(name, folders):
for folder in folders:
filename = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
if os.path.isfile(filename):
return filename
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
def ordinal_suffix(number: int) -> str:
return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')