feat: add preset selection to Gradio UI (session based) (#1570)

* add preset selection

uses meta parsing to set presets in the user session (UI elements only); see the sketch after this list

* add LoRA handling

* use default config as fallback value

* add preset refresh on "Refresh All Files" click

* add special handling for default_styles and default_aspect_ratio

* sort styles after preset change

* code cleanup

* download missing models from preset

* set default refiner to "None" in preset realistic

* use state_is_generating for preset selection change

* DRY output parameter handling

* feat: add argument --disable-preset-selection

useful for cloud provisioning to prevent model switches and keep models loaded

* feat: keep prompt when not set in preset, use more robust syntax

* fix: add default return values when preset download is disabled

https://github.com/mashb1t/Fooocus/issues/20

* feat: add translation for preset label

* refactor: unify preset loading methods in config

* refactor: code cleanup
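
A minimal sketch of the session-based approach, assuming the identifiers introduced in the diff below (the actual handler in webui.py additionally downloads any models the preset references and keeps the existing prompt when the preset does not set one):

# Simplified illustration only: a per-session preset Radio whose change handler turns the
# preset JSON into metadata-style parameters and applies them to UI elements, not to global config.
import json
import gradio as gr
import modules.config
import modules.meta_parser

preset_selection = gr.Radio(label='Preset',
                            choices=modules.config.available_presets,
                            value='initial', interactive=True)

def preset_selection_change(preset, is_generating):
    preset_content = modules.config.try_get_preset_content(preset) if preset != 'initial' else {}
    preset_prepared = modules.meta_parser.parse_meta_from_preset(preset_content)
    return modules.meta_parser.load_parameter_button_click(json.dumps(preset_prepared), is_generating)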
Manuel Schmid 2024-03-15 22:04:27 +01:00 committed by GitHub
parent 8baafcd79c
commit 4a44be36fd
7 changed files with 133 additions and 70 deletions

View File

@@ -4,7 +4,10 @@ import os
from tempfile import gettempdir
args_parser.parser.add_argument("--share", action='store_true', help="Set whether to share on Gradio.")
args_parser.parser.add_argument("--preset", type=str, default=None, help="Apply specified UI preset.")
args_parser.parser.add_argument("--disable-preset-selection", action='store_true',
help="Disables preset selection in Gradio.")
args_parser.parser.add_argument("--language", type=str, default='default',
help="Translate UI using json files in [language] folder. "

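Example invocation combining the new switch with the existing --preset argument (illustrative; the preset name must match a JSON file in the presets folder):

python launch.py --preset realistic --disable-preset-selection
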
View File

@@ -38,6 +38,7 @@
"* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)": "* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)",
"Setting": "Setting",
"Style": "Style",
"Preset": "Preset",
"Performance": "Performance",
"Speed": "Speed",
"Quality": "Quality",

View File

@@ -93,7 +93,7 @@ if config.temp_path_cleanup_on_launch:
print(f"[Cleanup] Failed to delete content of temp dir.")
def download_models():
def download_models(default_model, previous_default_models, checkpoint_downloads, embeddings_downloads, lora_downloads):
for file_name, url in vae_approx_filenames:
load_file_from_url(url=url, model_dir=config.path_vae_approx, file_name=file_name)
@@ -105,30 +105,32 @@ def download_models():
if args.disable_preset_download:
print('Skipped model download.')
return
return default_model, checkpoint_downloads
if not args.always_download_new_model:
if not os.path.exists(os.path.join(config.paths_checkpoints[0], config.default_base_model_name)):
for alternative_model_name in config.previous_default_models:
if not os.path.exists(os.path.join(config.paths_checkpoints[0], default_model)):
for alternative_model_name in previous_default_models:
if os.path.exists(os.path.join(config.paths_checkpoints[0], alternative_model_name)):
print(f'You do not have [{config.default_base_model_name}] but you have [{alternative_model_name}].')
print(f'You do not have [{default_model}] but you have [{alternative_model_name}].')
print(f'Fooocus will use [{alternative_model_name}] to avoid downloading new models, '
f'but you are not using latest models.')
f'but you are not using the latest models.')
print('Use --always-download-new-model to avoid fallback and always get new models.')
config.checkpoint_downloads = {}
config.default_base_model_name = alternative_model_name
checkpoint_downloads = {}
default_model = alternative_model_name
break
for file_name, url in config.checkpoint_downloads.items():
for file_name, url in checkpoint_downloads.items():
load_file_from_url(url=url, model_dir=config.paths_checkpoints[0], file_name=file_name)
for file_name, url in config.embeddings_downloads.items():
for file_name, url in embeddings_downloads.items():
load_file_from_url(url=url, model_dir=config.path_embeddings, file_name=file_name)
for file_name, url in config.lora_downloads.items():
for file_name, url in lora_downloads.items():
load_file_from_url(url=url, model_dir=config.paths_loras[0], file_name=file_name)
return
return default_model, checkpoint_downloads
download_models()
config.default_base_model_name, config.checkpoint_downloads = download_models(
config.default_base_model_name, config.previous_default_models, config.checkpoint_downloads,
config.embeddings_downloads, config.lora_downloads)
from webui import *

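Passing the preset values in as parameters and returning the (possibly substituted) model name and checkpoint downloads, instead of mutating modules.config, is what lets the preset change handler in webui.py below reuse the same download logic per session, roughly:

# illustrative call with values taken from a selected preset rather than from modules.config
default_model, checkpoint_downloads = launch.download_models(
    default_model, previous_default_models, checkpoint_downloads,
    embeddings_downloads, lora_downloads)
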
View File

@@ -97,21 +97,44 @@ def try_load_deprecated_user_path_config():
try_load_deprecated_user_path_config()
def get_presets():
preset_folder = 'presets'
presets = ['initial']
if not os.path.exists(preset_folder):
print('No presets found.')
return presets
return presets + [f[:f.index('.json')] for f in os.listdir(preset_folder) if f.endswith('.json')]
def try_get_preset_content(preset):
if isinstance(preset, str):
preset_path = os.path.abspath(f'./presets/{preset}.json')
try:
if os.path.exists(preset_path):
with open(preset_path, "r", encoding="utf-8") as json_file:
json_content = json.load(json_file)
print(f'Loaded preset: {preset_path}')
return json_content
else:
raise FileNotFoundError
except Exception as e:
print(f'Load preset [{preset_path}] failed')
print(e)
return {}
try:
with open(os.path.abspath(f'./presets/default.json'), "r", encoding="utf-8") as json_file:
config_dict.update(json.load(json_file))
except Exception as e:
print(f'Load default preset failed.')
print(e)
available_presets = get_presets()
preset = args_manager.args.preset
if isinstance(preset, str):
preset_path = os.path.abspath(f'./presets/{preset}.json')
try:
if os.path.exists(preset_path):
with open(preset_path, "r", encoding="utf-8") as json_file:
config_dict.update(json.load(json_file))
print(f'Loaded preset: {preset_path}')
else:
raise FileNotFoundError
except Exception as e:
print(f'Load preset [{preset_path}] failed')
print(e)
config_dict.update(try_get_preset_content(preset))
def get_path_output() -> str:
"""
@@ -241,7 +264,7 @@ temp_path_cleanup_on_launch = get_config_item_or_set_default(
default_value=True,
validator=lambda x: isinstance(x, bool)
)
default_base_model_name = get_config_item_or_set_default(
default_base_model_name = default_model = get_config_item_or_set_default(
key='default_model',
default_value='model.safetensors',
validator=lambda x: isinstance(x, str)
@@ -251,7 +274,7 @@ previous_default_models = get_config_item_or_set_default(
default_value=[],
validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x)
)
default_refiner_model_name = get_config_item_or_set_default(
default_refiner_model_name = default_refiner = get_config_item_or_set_default(
key='default_refiner',
default_value='None',
validator=lambda x: isinstance(x, str)
@@ -451,29 +474,30 @@ example_inpaint_prompts = [[x] for x in example_inpaint_prompts]
config_dict["default_loras"] = default_loras = default_loras[:default_max_lora_number] + [['None', 1.0] for _ in range(default_max_lora_number - len(default_loras))]
possible_preset_keys = [
"default_model",
"default_refiner",
"default_refiner_switch",
"default_loras_min_weight",
"default_loras_max_weight",
"default_loras",
"default_max_lora_number",
"default_cfg_scale",
"default_sample_sharpness",
"default_sampler",
"default_scheduler",
"default_performance",
"default_prompt",
"default_prompt_negative",
"default_styles",
"default_aspect_ratio",
"default_save_metadata_to_images",
"checkpoint_downloads",
"embeddings_downloads",
"lora_downloads",
]
# mapping config to meta parameter
possible_preset_keys = {
"default_model": "base_model",
"default_refiner": "refiner_model",
"default_refiner_switch": "refiner_switch",
"previous_default_models": "previous_default_models",
"default_loras_min_weight": "default_loras_min_weight",
"default_loras_max_weight": "default_loras_max_weight",
"default_loras": "<processed>",
"default_cfg_scale": "guidance_scale",
"default_sample_sharpness": "sharpness",
"default_sampler": "sampler",
"default_scheduler": "scheduler",
"default_overwrite_step": "steps",
"default_performance": "performance",
"default_prompt": "prompt",
"default_prompt_negative": "negative_prompt",
"default_styles": "styles",
"default_aspect_ratio": "resolution",
"default_save_metadata_to_images": "default_save_metadata_to_images",
"checkpoint_downloads": "checkpoint_downloads",
"embeddings_downloads": "embeddings_downloads",
"lora_downloads": "lora_downloads"
}
REWRITE_PRESET = False
@@ -530,10 +554,11 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
def update_files():
global model_filenames, lora_filenames, wildcard_filenames
global model_filenames, lora_filenames, wildcard_filenames, available_presets
model_filenames = get_model_filenames(paths_checkpoints)
lora_filenames = get_model_filenames(paths_loras)
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
available_presets = get_presets()
return

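This mapping is consumed by parse_meta_from_preset in the next file. A simplified sketch of how each key resolves, using a hypothetical helper (the real function additionally special-cases default_loras, default_styles and default_aspect_ratio):

import modules.config

def map_preset_to_meta(preset_content: dict) -> dict:
    # hypothetical helper for illustration; mirrors the fallback behaviour shown below
    prepared = {}
    for settings_key, meta_key in modules.config.possible_preset_keys.items():
        value = preset_content.get(settings_key)
        # fall back to the default config attribute when the preset omits the key
        prepared[meta_key] = value if value is not None else getattr(modules.config, settings_key)
    return prepared

This fallback via getattr is presumably also why default_base_model_name and default_refiner_model_name gain the default_model and default_refiner aliases above: the module attribute names then match the preset keys.
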
View File

@@ -210,9 +210,8 @@ def parse_meta_from_preset(preset_content):
height = height[:height.index(" ")]
preset_prepared[meta_key] = (width, height)
else:
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[
settings_key] is not None else getattr(modules.config, settings_key)
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[settings_key] is not None else getattr(modules.config, settings_key)
if settings_key == "default_styles" or settings_key == "default_aspect_ratio":
preset_prepared[meta_key] = str(preset_prepared[meta_key])
@@ -570,4 +569,4 @@ def get_exif(metadata: str | None, metadata_scheme: str):
exif[0x0131] = 'Fooocus v' + fooocus_version.version
# 0x927C = MakerNote
exif[0x927C] = metadata_scheme
return exif
return exif

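A worked example (illustrative values) of the str() special-casing earlier in this file's diff: a styles list from the preset ends up in the same stringified form used when parameters are loaded back from image metadata.

# "default_styles": ["Fooocus V2", "Fooocus Enhance"]
# is prepared under the metadata key "styles" as the string "['Fooocus V2', 'Fooocus Enhance']"
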
View File

@@ -1,6 +1,6 @@
{
"default_model": "realisticStockPhoto_v20.safetensors",
"default_refiner": "",
"default_refiner": "None",
"default_refiner_switch": 0.5,
"default_loras": [
[

View File

@ -15,6 +15,7 @@ import modules.style_sorter as style_sorter
import modules.meta_parser
import args_manager
import copy
import launch
from modules.sdxl_styles import legal_style_names
from modules.private_logger import get_current_html_path
@@ -252,6 +253,11 @@ with shared.gradio_root:
with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
with gr.Tab(label='Setting'):
if not args_manager.args.disable_preset_selection:
preset_selection = gr.Radio(label='Preset',
choices=modules.config.available_presets,
value=args_manager.args.preset if args_manager.args.preset else "initial",
interactive=True)
performance_selection = gr.Radio(label='Performance',
choices=flags.Performance.list(),
value=modules.config.default_performance)
@@ -518,13 +524,50 @@ with shared.gradio_root:
modules.config.update_files()
results = [gr.update(choices=modules.config.model_filenames)]
results += [gr.update(choices=['None'] + modules.config.model_filenames)]
if not args_manager.args.disable_preset_selection:
results += [gr.update(choices=modules.config.available_presets)]
for i in range(modules.config.default_max_lora_number):
results += [gr.update(interactive=True), gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
results += [gr.update(interactive=True),
gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
return results
refresh_files.click(refresh_files_clicked, [], [base_model, refiner_model] + lora_ctrls,
refresh_files_output = [base_model, refiner_model]
if not args_manager.args.disable_preset_selection:
refresh_files_output += [preset_selection]
refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls,
queue=False, show_progress=False)
state_is_generating = gr.State(False)
load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections,
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed,
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
if not args_manager.args.disable_preset_selection:
def preset_selection_change(preset, is_generating):
preset_content = modules.config.try_get_preset_content(preset) if preset != 'initial' else {}
preset_prepared = modules.meta_parser.parse_meta_from_preset(preset_content)
default_model = preset_prepared.get('base_model')
previous_default_models = preset_prepared.get('previous_default_models', [])
checkpoint_downloads = preset_prepared.get('checkpoint_downloads', {})
embeddings_downloads = preset_prepared.get('embeddings_downloads', {})
lora_downloads = preset_prepared.get('lora_downloads', {})
preset_prepared['base_model'], preset_prepared['lora_downloads'] = launch.download_models(
default_model, previous_default_models, checkpoint_downloads, embeddings_downloads, lora_downloads)
if 'prompt' in preset_prepared and preset_prepared.get('prompt') == '':
del preset_prepared['prompt']
return modules.meta_parser.load_parameter_button_click(json.dumps(preset_prepared), is_generating)
preset_selection.change(preset_selection_change, inputs=[preset_selection, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \
.then(fn=style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False) \
performance_selection.change(lambda x: [gr.update(interactive=not flags.Performance.has_restricted_features(x))] * 11 +
[gr.update(visible=not flags.Performance.has_restricted_features(x))] * 1 +
[gr.update(interactive=not flags.Performance.has_restricted_features(x), value=flags.Performance.has_restricted_features(x))] * 1,
@@ -600,8 +643,6 @@ with shared.gradio_root:
ctrls += ip_ctrls
state_is_generating = gr.State(False)
def parse_meta(raw_prompt_txt, is_generating):
loaded_json = None
if is_json(raw_prompt_txt):
@@ -617,13 +658,6 @@ with shared.gradio_root:
prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False)
load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections,
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed,
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=False)
def trigger_metadata_import(filepath, state_is_generating):
@@ -637,7 +671,6 @@ with shared.gradio_root:
return modules.meta_parser.load_parameter_button_click(parsed_parameters, state_is_generating)
metadata_import_button.click(trigger_metadata_import, inputs=[metadata_input_image, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \
.then(style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False)