diff --git a/args_manager.py b/args_manager.py index e5e7675..eeb38e1 100644 --- a/args_manager.py +++ b/args_manager.py @@ -20,6 +20,12 @@ args_parser.parser.add_argument("--disable-image-log", action='store_true', args_parser.parser.add_argument("--disable-analytics", action='store_true', help="Disables analytics for Gradio", default=False) +args_parser.parser.add_argument("--disable-preset-download", action='store_true', + help="Disables downloading models for presets", default=False) + +args_parser.parser.add_argument("--always-download-new-model", action='store_true', + help="Always download newer models ", default=False) + args_parser.parser.set_defaults( disable_cuda_malloc=True, in_browser=True, diff --git a/launch.py b/launch.py index e98045f..9dbd3b6 100644 --- a/launch.py +++ b/launch.py @@ -21,8 +21,7 @@ import fooocus_version from build_launcher import build_launcher from modules.launch_util import is_installed, run, python, run_pip, requirements_met from modules.model_loader import load_file_from_url -from modules.config import path_checkpoints, path_loras, path_vae_approx, path_fooocus_expansion, \ - checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads +from modules import config REINSTALL_ALL = False @@ -70,25 +69,6 @@ vae_approx_filenames = [ ] -def download_models(): - for file_name, url in checkpoint_downloads.items(): - load_file_from_url(url=url, model_dir=path_checkpoints, file_name=file_name) - for file_name, url in embeddings_downloads.items(): - load_file_from_url(url=url, model_dir=path_embeddings, file_name=file_name) - for file_name, url in lora_downloads.items(): - load_file_from_url(url=url, model_dir=path_loras, file_name=file_name) - for file_name, url in vae_approx_filenames: - load_file_from_url(url=url, model_dir=path_vae_approx, file_name=file_name) - - load_file_from_url( - url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin', - model_dir=path_fooocus_expansion, - file_name='pytorch_model.bin' - ) - - return - - def ini_args(): from args_manager import args return args @@ -104,6 +84,43 @@ if args.gpu_device_id is not None: print("Set device to:", args.gpu_device_id) +def download_models(): + for file_name, url in vae_approx_filenames: + load_file_from_url(url=url, model_dir=config.path_vae_approx, file_name=file_name) + + load_file_from_url( + url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin', + model_dir=config.path_fooocus_expansion, + file_name='pytorch_model.bin' + ) + + if args.disable_preset_download: + print('Skipped model download.') + return + + if not args.always_download_new_model: + if not os.path.exists(os.path.join(config.path_checkpoints, config.default_base_model_name)): + for alternative_model_name in config.previous_default_models: + if os.path.exists(os.path.join(config.path_checkpoints, alternative_model_name)): + print(f'You do not have [{config.default_base_model_name}] but you have [{alternative_model_name}].') + print(f'Fooocus will use [{alternative_model_name}] to avoid downloading new models, ' + f'but you are not using latest models.') + print('Use --always-download-new-model to avoid fallback and always get new models.') + config.checkpoint_downloads = {} + config.default_base_model_name = alternative_model_name + break + + for file_name, url in config.checkpoint_downloads.items(): + load_file_from_url(url=url, model_dir=config.path_checkpoints, file_name=file_name) + for file_name, url in config.embeddings_downloads.items(): + load_file_from_url(url=url, model_dir=config.path_embeddings, file_name=file_name) + for file_name, url in config.lora_downloads.items(): + load_file_from_url(url=url, model_dir=config.path_loras, file_name=file_name) + + return + + download_models() + from webui import * diff --git a/modules/config.py b/modules/config.py index c7af33d..7feae8f 100644 --- a/modules/config.py +++ b/modules/config.py @@ -79,6 +79,13 @@ def try_load_deprecated_user_path_config(): try_load_deprecated_user_path_config() +try: + with open(os.path.abspath(f'./presets/default.json'), "r", encoding="utf-8") as json_file: + config_dict.update(json.load(json_file)) +except Exception as e: + print(f'Load default preset failed.') + print(e) + preset = args_manager.args.preset if isinstance(preset, str): @@ -153,9 +160,14 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ default_base_model_name = get_config_item_or_set_default( key='default_model', - default_value='juggernautXL_version6Rundiffusion.safetensors', + default_value='model.safetensors', validator=lambda x: isinstance(x, str) ) +previous_default_models = get_config_item_or_set_default( + key='previous_default_models', + default_value=[], + validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x) +) default_refiner_model_name = get_config_item_or_set_default( key='default_refiner', default_value='None', @@ -163,15 +175,15 @@ default_refiner_model_name = get_config_item_or_set_default( ) default_refiner_switch = get_config_item_or_set_default( key='default_refiner_switch', - default_value=0.5, + default_value=0.8, validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1 ) default_loras = get_config_item_or_set_default( key='default_loras', default_value=[ [ - "sd_xl_offset_example-lora_1.0.safetensors", - 0.1 + "None", + 1.0 ], [ "None", @@ -194,7 +206,7 @@ default_loras = get_config_item_or_set_default( ) default_cfg_scale = get_config_item_or_set_default( key='default_cfg_scale', - default_value=4.0, + default_value=7.0, validator=lambda x: isinstance(x, numbers.Number) ) default_sample_sharpness = get_config_item_or_set_default( @@ -255,16 +267,12 @@ default_image_number = get_config_item_or_set_default( ) checkpoint_downloads = get_config_item_or_set_default( key='checkpoint_downloads', - default_value={ - "juggernautXL_version6Rundiffusion.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors" - }, + default_value={}, validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) ) lora_downloads = get_config_item_or_set_default( key='lora_downloads', - default_value={ - "sd_xl_offset_example-lora_1.0.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors" - }, + default_value={}, validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) ) embeddings_downloads = get_config_item_or_set_default(