feat: add preset selection to Gradio UI (session based) (#1570)
* add preset selection uses meta parsing to set presets in user session (UI elements only) * add LoRA handling * use default config as fallback value * add preset refresh on "Refresh All Files" click * add special handling for default_styles and default_aspect_ratio * sort styles after preset change * code cleanup * download missing models from preset * set default refiner to "None" in preset realistic * use state_is_generating for preset selection change * DRY output parameter handling * feat: add argument --disable-preset-selection useful for cloud provisioning to prevent model switches and keep models loaded * feat: keep prompt when not set in preset, use more robust syntax * fix: add default return values when preset download is disabled https://github.com/mashb1t/Fooocus/issues/20 * feat: add translation for preset label * refactor: unify preset loading methods in config * refactor: code cleanup
This commit is contained in:
parent
8baafcd79c
commit
4a44be36fd
@ -4,7 +4,10 @@ import os
|
||||
from tempfile import gettempdir
|
||||
|
||||
args_parser.parser.add_argument("--share", action='store_true', help="Set whether to share on Gradio.")
|
||||
|
||||
args_parser.parser.add_argument("--preset", type=str, default=None, help="Apply specified UI preset.")
|
||||
args_parser.parser.add_argument("--disable-preset-selection", action='store_true',
|
||||
help="Disables preset selection in Gradio.")
|
||||
|
||||
args_parser.parser.add_argument("--language", type=str, default='default',
|
||||
help="Translate UI using json files in [language] folder. "
|
||||
|
@ -38,6 +38,7 @@
|
||||
"* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)": "* \"Inpaint or Outpaint\" is powered by the sampler \"DPMPP Fooocus Seamless 2M SDE Karras Inpaint Sampler\" (beta)",
|
||||
"Setting": "Setting",
|
||||
"Style": "Style",
|
||||
"Preset": "Preset",
|
||||
"Performance": "Performance",
|
||||
"Speed": "Speed",
|
||||
"Quality": "Quality",
|
||||
|
28
launch.py
28
launch.py
@ -93,7 +93,7 @@ if config.temp_path_cleanup_on_launch:
|
||||
print(f"[Cleanup] Failed to delete content of temp dir.")
|
||||
|
||||
|
||||
def download_models():
|
||||
def download_models(default_model, previous_default_models, checkpoint_downloads, embeddings_downloads, lora_downloads):
|
||||
for file_name, url in vae_approx_filenames:
|
||||
load_file_from_url(url=url, model_dir=config.path_vae_approx, file_name=file_name)
|
||||
|
||||
@ -105,30 +105,32 @@ def download_models():
|
||||
|
||||
if args.disable_preset_download:
|
||||
print('Skipped model download.')
|
||||
return
|
||||
return default_model, checkpoint_downloads
|
||||
|
||||
if not args.always_download_new_model:
|
||||
if not os.path.exists(os.path.join(config.paths_checkpoints[0], config.default_base_model_name)):
|
||||
for alternative_model_name in config.previous_default_models:
|
||||
if not os.path.exists(os.path.join(config.paths_checkpoints[0], default_model)):
|
||||
for alternative_model_name in previous_default_models:
|
||||
if os.path.exists(os.path.join(config.paths_checkpoints[0], alternative_model_name)):
|
||||
print(f'You do not have [{config.default_base_model_name}] but you have [{alternative_model_name}].')
|
||||
print(f'You do not have [{default_model}] but you have [{alternative_model_name}].')
|
||||
print(f'Fooocus will use [{alternative_model_name}] to avoid downloading new models, '
|
||||
f'but you are not using latest models.')
|
||||
f'but you are not using the latest models.')
|
||||
print('Use --always-download-new-model to avoid fallback and always get new models.')
|
||||
config.checkpoint_downloads = {}
|
||||
config.default_base_model_name = alternative_model_name
|
||||
checkpoint_downloads = {}
|
||||
default_model = alternative_model_name
|
||||
break
|
||||
|
||||
for file_name, url in config.checkpoint_downloads.items():
|
||||
for file_name, url in checkpoint_downloads.items():
|
||||
load_file_from_url(url=url, model_dir=config.paths_checkpoints[0], file_name=file_name)
|
||||
for file_name, url in config.embeddings_downloads.items():
|
||||
for file_name, url in embeddings_downloads.items():
|
||||
load_file_from_url(url=url, model_dir=config.path_embeddings, file_name=file_name)
|
||||
for file_name, url in config.lora_downloads.items():
|
||||
for file_name, url in lora_downloads.items():
|
||||
load_file_from_url(url=url, model_dir=config.paths_loras[0], file_name=file_name)
|
||||
|
||||
return
|
||||
return default_model, checkpoint_downloads
|
||||
|
||||
|
||||
download_models()
|
||||
config.default_base_model_name, config.checkpoint_downloads = download_models(
|
||||
config.default_base_model_name, config.previous_default_models, config.checkpoint_downloads,
|
||||
config.embeddings_downloads, config.lora_downloads)
|
||||
|
||||
from webui import *
|
||||
|
@ -97,21 +97,44 @@ def try_load_deprecated_user_path_config():
|
||||
|
||||
try_load_deprecated_user_path_config()
|
||||
|
||||
|
||||
def get_presets():
|
||||
preset_folder = 'presets'
|
||||
presets = ['initial']
|
||||
if not os.path.exists(preset_folder):
|
||||
print('No presets found.')
|
||||
return presets
|
||||
|
||||
return presets + [f[:f.index('.json')] for f in os.listdir(preset_folder) if f.endswith('.json')]
|
||||
|
||||
|
||||
def try_get_preset_content(preset):
|
||||
if isinstance(preset, str):
|
||||
preset_path = os.path.abspath(f'./presets/{preset}.json')
|
||||
try:
|
||||
if os.path.exists(preset_path):
|
||||
with open(preset_path, "r", encoding="utf-8") as json_file:
|
||||
json_content = json.load(json_file)
|
||||
print(f'Loaded preset: {preset_path}')
|
||||
return json_content
|
||||
else:
|
||||
raise FileNotFoundError
|
||||
except Exception as e:
|
||||
print(f'Load preset [{preset_path}] failed')
|
||||
print(e)
|
||||
return {}
|
||||
|
||||
|
||||
try:
|
||||
with open(os.path.abspath(f'./presets/default.json'), "r", encoding="utf-8") as json_file:
|
||||
config_dict.update(json.load(json_file))
|
||||
except Exception as e:
|
||||
print(f'Load default preset failed.')
|
||||
print(e)
|
||||
|
||||
available_presets = get_presets()
|
||||
preset = args_manager.args.preset
|
||||
|
||||
if isinstance(preset, str):
|
||||
preset_path = os.path.abspath(f'./presets/{preset}.json')
|
||||
try:
|
||||
if os.path.exists(preset_path):
|
||||
with open(preset_path, "r", encoding="utf-8") as json_file:
|
||||
config_dict.update(json.load(json_file))
|
||||
print(f'Loaded preset: {preset_path}')
|
||||
else:
|
||||
raise FileNotFoundError
|
||||
except Exception as e:
|
||||
print(f'Load preset [{preset_path}] failed')
|
||||
print(e)
|
||||
|
||||
config_dict.update(try_get_preset_content(preset))
|
||||
|
||||
def get_path_output() -> str:
|
||||
"""
|
||||
@ -241,7 +264,7 @@ temp_path_cleanup_on_launch = get_config_item_or_set_default(
|
||||
default_value=True,
|
||||
validator=lambda x: isinstance(x, bool)
|
||||
)
|
||||
default_base_model_name = get_config_item_or_set_default(
|
||||
default_base_model_name = default_model = get_config_item_or_set_default(
|
||||
key='default_model',
|
||||
default_value='model.safetensors',
|
||||
validator=lambda x: isinstance(x, str)
|
||||
@ -251,7 +274,7 @@ previous_default_models = get_config_item_or_set_default(
|
||||
default_value=[],
|
||||
validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x)
|
||||
)
|
||||
default_refiner_model_name = get_config_item_or_set_default(
|
||||
default_refiner_model_name = default_refiner = get_config_item_or_set_default(
|
||||
key='default_refiner',
|
||||
default_value='None',
|
||||
validator=lambda x: isinstance(x, str)
|
||||
@ -451,29 +474,30 @@ example_inpaint_prompts = [[x] for x in example_inpaint_prompts]
|
||||
|
||||
config_dict["default_loras"] = default_loras = default_loras[:default_max_lora_number] + [['None', 1.0] for _ in range(default_max_lora_number - len(default_loras))]
|
||||
|
||||
possible_preset_keys = [
|
||||
"default_model",
|
||||
"default_refiner",
|
||||
"default_refiner_switch",
|
||||
"default_loras_min_weight",
|
||||
"default_loras_max_weight",
|
||||
"default_loras",
|
||||
"default_max_lora_number",
|
||||
"default_cfg_scale",
|
||||
"default_sample_sharpness",
|
||||
"default_sampler",
|
||||
"default_scheduler",
|
||||
"default_performance",
|
||||
"default_prompt",
|
||||
"default_prompt_negative",
|
||||
"default_styles",
|
||||
"default_aspect_ratio",
|
||||
"default_save_metadata_to_images",
|
||||
"checkpoint_downloads",
|
||||
"embeddings_downloads",
|
||||
"lora_downloads",
|
||||
]
|
||||
|
||||
# mapping config to meta parameter
|
||||
possible_preset_keys = {
|
||||
"default_model": "base_model",
|
||||
"default_refiner": "refiner_model",
|
||||
"default_refiner_switch": "refiner_switch",
|
||||
"previous_default_models": "previous_default_models",
|
||||
"default_loras_min_weight": "default_loras_min_weight",
|
||||
"default_loras_max_weight": "default_loras_max_weight",
|
||||
"default_loras": "<processed>",
|
||||
"default_cfg_scale": "guidance_scale",
|
||||
"default_sample_sharpness": "sharpness",
|
||||
"default_sampler": "sampler",
|
||||
"default_scheduler": "scheduler",
|
||||
"default_overwrite_step": "steps",
|
||||
"default_performance": "performance",
|
||||
"default_prompt": "prompt",
|
||||
"default_prompt_negative": "negative_prompt",
|
||||
"default_styles": "styles",
|
||||
"default_aspect_ratio": "resolution",
|
||||
"default_save_metadata_to_images": "default_save_metadata_to_images",
|
||||
"checkpoint_downloads": "checkpoint_downloads",
|
||||
"embeddings_downloads": "embeddings_downloads",
|
||||
"lora_downloads": "lora_downloads"
|
||||
}
|
||||
|
||||
REWRITE_PRESET = False
|
||||
|
||||
@ -530,10 +554,11 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
|
||||
|
||||
|
||||
def update_files():
|
||||
global model_filenames, lora_filenames, wildcard_filenames
|
||||
global model_filenames, lora_filenames, wildcard_filenames, available_presets
|
||||
model_filenames = get_model_filenames(paths_checkpoints)
|
||||
lora_filenames = get_model_filenames(paths_loras)
|
||||
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
|
||||
available_presets = get_presets()
|
||||
return
|
||||
|
||||
|
||||
|
@ -210,9 +210,8 @@ def parse_meta_from_preset(preset_content):
|
||||
height = height[:height.index(" ")]
|
||||
preset_prepared[meta_key] = (width, height)
|
||||
else:
|
||||
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[
|
||||
settings_key] is not None else getattr(modules.config, settings_key)
|
||||
|
||||
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[settings_key] is not None else getattr(modules.config, settings_key)
|
||||
|
||||
if settings_key == "default_styles" or settings_key == "default_aspect_ratio":
|
||||
preset_prepared[meta_key] = str(preset_prepared[meta_key])
|
||||
|
||||
@ -570,4 +569,4 @@ def get_exif(metadata: str | None, metadata_scheme: str):
|
||||
exif[0x0131] = 'Fooocus v' + fooocus_version.version
|
||||
# 0x927C = MakerNote
|
||||
exif[0x927C] = metadata_scheme
|
||||
return exif
|
||||
return exif
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"default_model": "realisticStockPhoto_v20.safetensors",
|
||||
"default_refiner": "",
|
||||
"default_refiner": "None",
|
||||
"default_refiner_switch": 0.5,
|
||||
"default_loras": [
|
||||
[
|
||||
|
57
webui.py
57
webui.py
@ -15,6 +15,7 @@ import modules.style_sorter as style_sorter
|
||||
import modules.meta_parser
|
||||
import args_manager
|
||||
import copy
|
||||
import launch
|
||||
|
||||
from modules.sdxl_styles import legal_style_names
|
||||
from modules.private_logger import get_current_html_path
|
||||
@ -252,6 +253,11 @@ with shared.gradio_root:
|
||||
|
||||
with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
|
||||
with gr.Tab(label='Setting'):
|
||||
if not args_manager.args.disable_preset_selection:
|
||||
preset_selection = gr.Radio(label='Preset',
|
||||
choices=modules.config.available_presets,
|
||||
value=args_manager.args.preset if args_manager.args.preset else "initial",
|
||||
interactive=True)
|
||||
performance_selection = gr.Radio(label='Performance',
|
||||
choices=flags.Performance.list(),
|
||||
value=modules.config.default_performance)
|
||||
@ -518,13 +524,50 @@ with shared.gradio_root:
|
||||
modules.config.update_files()
|
||||
results = [gr.update(choices=modules.config.model_filenames)]
|
||||
results += [gr.update(choices=['None'] + modules.config.model_filenames)]
|
||||
if not args_manager.args.disable_preset_selection:
|
||||
results += [gr.update(choices=modules.config.available_presets)]
|
||||
for i in range(modules.config.default_max_lora_number):
|
||||
results += [gr.update(interactive=True), gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
|
||||
results += [gr.update(interactive=True),
|
||||
gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
|
||||
return results
|
||||
|
||||
refresh_files.click(refresh_files_clicked, [], [base_model, refiner_model] + lora_ctrls,
|
||||
refresh_files_output = [base_model, refiner_model]
|
||||
if not args_manager.args.disable_preset_selection:
|
||||
refresh_files_output += [preset_selection]
|
||||
refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls,
|
||||
queue=False, show_progress=False)
|
||||
|
||||
state_is_generating = gr.State(False)
|
||||
|
||||
load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections,
|
||||
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
|
||||
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
|
||||
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
|
||||
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed,
|
||||
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
|
||||
|
||||
if not args_manager.args.disable_preset_selection:
|
||||
def preset_selection_change(preset, is_generating):
|
||||
preset_content = modules.config.try_get_preset_content(preset) if preset != 'initial' else {}
|
||||
preset_prepared = modules.meta_parser.parse_meta_from_preset(preset_content)
|
||||
|
||||
default_model = preset_prepared.get('base_model')
|
||||
previous_default_models = preset_prepared.get('previous_default_models', [])
|
||||
checkpoint_downloads = preset_prepared.get('checkpoint_downloads', {})
|
||||
embeddings_downloads = preset_prepared.get('embeddings_downloads', {})
|
||||
lora_downloads = preset_prepared.get('lora_downloads', {})
|
||||
|
||||
preset_prepared['base_model'], preset_prepared['lora_downloads'] = launch.download_models(
|
||||
default_model, previous_default_models, checkpoint_downloads, embeddings_downloads, lora_downloads)
|
||||
|
||||
if 'prompt' in preset_prepared and preset_prepared.get('prompt') == '':
|
||||
del preset_prepared['prompt']
|
||||
|
||||
return modules.meta_parser.load_parameter_button_click(json.dumps(preset_prepared), is_generating)
|
||||
|
||||
preset_selection.change(preset_selection_change, inputs=[preset_selection, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \
|
||||
.then(fn=style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False) \
|
||||
|
||||
performance_selection.change(lambda x: [gr.update(interactive=not flags.Performance.has_restricted_features(x))] * 11 +
|
||||
[gr.update(visible=not flags.Performance.has_restricted_features(x))] * 1 +
|
||||
[gr.update(interactive=not flags.Performance.has_restricted_features(x), value=flags.Performance.has_restricted_features(x))] * 1,
|
||||
@ -600,8 +643,6 @@ with shared.gradio_root:
|
||||
|
||||
ctrls += ip_ctrls
|
||||
|
||||
state_is_generating = gr.State(False)
|
||||
|
||||
def parse_meta(raw_prompt_txt, is_generating):
|
||||
loaded_json = None
|
||||
if is_json(raw_prompt_txt):
|
||||
@ -617,13 +658,6 @@ with shared.gradio_root:
|
||||
|
||||
prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False)
|
||||
|
||||
load_data_outputs = [advanced_checkbox, image_number, prompt, negative_prompt, style_selections,
|
||||
performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection,
|
||||
overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive,
|
||||
adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model,
|
||||
refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed,
|
||||
generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls
|
||||
|
||||
load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=False)
|
||||
|
||||
def trigger_metadata_import(filepath, state_is_generating):
|
||||
@ -637,7 +671,6 @@ with shared.gradio_root:
|
||||
|
||||
return modules.meta_parser.load_parameter_button_click(parsed_parameters, state_is_generating)
|
||||
|
||||
|
||||
metadata_import_button.click(trigger_metadata_import, inputs=[metadata_input_image, state_is_generating], outputs=load_data_outputs, queue=False, show_progress=True) \
|
||||
.then(style_sorter.sort_styles, inputs=style_selections, outputs=style_selections, queue=False, show_progress=False)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user