diff --git a/.gitignore b/.gitignore index cafa6db..dcbea2e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__ *.patch *.backup *.corrupted +sorted_styles.json /language/default.json lena.png lena_result.png diff --git a/javascript/localization.js b/javascript/localization.js index 7c6ad83..8fda68e 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -73,6 +73,10 @@ function processNode(node) { }); } +function refresh_style_localization() { + processNode(document.querySelector('.style_selections')); +} + function localizeWholePage() { processNode(gradioApp()); diff --git a/javascript/viewer.js b/javascript/viewer.js index 634ef8f..0e34a93 100644 --- a/javascript/viewer.js +++ b/javascript/viewer.js @@ -7,7 +7,7 @@ function refresh_grid() { if (gridContainer) if (final_gallery) { let rect = final_gallery.getBoundingClientRect(); let cols = Math.ceil((rect.width - 16.0) / rect.height); - if(cols < 2) cols = 2; + if (cols < 2) cols = 2; gridContainer.style.setProperty('--grid-cols', cols); } } @@ -56,3 +56,27 @@ window.addEventListener('resize', (e) => { onUiLoaded(async () => { resized(); }); + +function on_style_selection_blur() { + let target = document.querySelector("#gradio_receiver_style_selections textarea"); + target.value = "on_style_selection_blur " + Math.random(); + let e = new Event("input", {bubbles: true}) + Object.defineProperty(e, "target", {value: target}) + target.dispatchEvent(e); +} + +onUiLoaded(async () => { + let spans = document.querySelectorAll('.aspect_ratios span'); + + spans.forEach(function (span) { + span.innerHTML = span.innerHTML.replace(/</g, '<').replace(/>/g, '>'); + }); + + document.querySelector('.style_selections').addEventListener('focusout', function (event) { + setTimeout(() => { + if (!this.contains(document.activeElement)) { + on_style_selection_blur(); + } + }, 200); + }); +}); diff --git a/modules/async_worker.py b/modules/async_worker.py index aa5ea83..3860e45 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -209,7 +209,7 @@ def worker(): tiled = False inpaint_worker.current_task = None - width, height = aspect_ratios_selection.split('×') + width, height = aspect_ratios_selection.replace('×', ' ').split(' ')[:2] width, height = int(width), int(height) skip_prompt_processing = False diff --git a/modules/config.py b/modules/config.py index 61395f9..8cf77d2 100644 --- a/modules/config.py +++ b/modules/config.py @@ -1,5 +1,6 @@ import os import json +import math import args_manager import modules.flags import modules.sdxl_styles @@ -247,6 +248,17 @@ default_aspect_ratio = get_config_item_or_set_default( validator=lambda x: x in available_aspect_ratios ) + +def add_ratio(x): + a, b = x.replace('*', ' ').split(' ')[:2] + a, b = int(a), int(b) + g = math.gcd(a, b) + return f'{a}×{b} \U00002223 {a // g}:{b // g}' + + +default_aspect_ratio = add_ratio(default_aspect_ratio) +available_aspect_ratios = [add_ratio(x) for x in available_aspect_ratios] + with open(config_path, "w", encoding="utf-8") as json_file: json.dump({k: config_dict[k] for k in always_save_keys}, json_file, indent=4) @@ -264,9 +276,6 @@ os.makedirs(path_outputs, exist_ok=True) model_filenames = [] lora_filenames = [] -available_aspect_ratios = [x.replace('*', '×') for x in available_aspect_ratios] -default_aspect_ratio = default_aspect_ratio.replace('*', '×') - def get_model_filenames(folder_path, name_filter=None): return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter) diff --git a/modules/html.py b/modules/html.py index 8afe9ac..3ec6f2d 100644 --- a/modules/html.py +++ b/modules/html.py @@ -100,6 +100,18 @@ progress::after { overflow: auto !important; } +.aspect_ratios label { + width: 140px !important; +} + +.aspect_ratios label span { + white-space: nowrap !important; +} + +.aspect_ratios label input { + margin-left: -5px !important; +} + ''' progress_html = '''
diff --git a/modules/localization.py b/modules/localization.py index 9403b17..b21d4a5 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -2,29 +2,30 @@ import json import os +current_translation = {} localization_root = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'language') def localization_js(filename): - data = {} + global current_translation if isinstance(filename, str): full_name = os.path.abspath(os.path.join(localization_root, filename + '.json')) if os.path.exists(full_name): try: with open(full_name, encoding='utf-8') as f: - data = json.load(f) - assert isinstance(data, dict) - for k, v in data.items(): + current_translation = json.load(f) + assert isinstance(current_translation, dict) + for k, v in current_translation.items(): assert isinstance(k, str) assert isinstance(v, str) except Exception as e: print(str(e)) print(f'Failed to load localization file {full_name}') - # data = {k: 'XXX' for k in data.keys()} # use this to see if all texts are covered + # current_translation = {k: 'XXX' for k in current_translation.keys()} # use this to see if all texts are covered - return f"window.localization = {json.dumps(data)}" + return f"window.localization = {json.dumps(current_translation)}" def dump_english_config(components): diff --git a/modules/style_sorter.py b/modules/style_sorter.py new file mode 100644 index 0000000..21742bd --- /dev/null +++ b/modules/style_sorter.py @@ -0,0 +1,54 @@ +import os +import gradio as gr +import modules.localization as localization +import json + + +all_styles = [] + + +def try_load_sorted_styles(style_names, default_selected): + global all_styles + + all_styles = style_names + + try: + if os.path.exists('sorted_styles.json'): + with open('sorted_styles.json', 'rt', encoding='utf-8') as fp: + sorted_styles = json.load(fp) + if len(sorted_styles) == len(all_styles): + if all(x in all_styles for x in sorted_styles): + if all(x in sorted_styles for x in all_styles): + all_styles = sorted_styles + except Exception as e: + print('Load style sorting failed.') + print(e) + + unselected = [y for y in all_styles if y not in default_selected] + all_styles = default_selected + unselected + + return all_styles + + +def sort_styles(selected): + unselected = [y for y in all_styles if y not in selected] + sorted_styles = selected + unselected + try: + with open('sorted_styles.json', 'wt', encoding='utf-8') as fp: + json.dump(sorted_styles, fp, indent=4) + except Exception as e: + print('Write style sorting failed.') + print(e) + return gr.CheckboxGroup.update(choices=sorted_styles) + + +def localization_key(x): + return x + localization.current_translation.get(x, '') + + +def search_styles(selected, query): + unselected = [y for y in all_styles if y not in selected] + matched = [y for y in unselected if query.lower() in localization_key(y).lower()] if len(query) > 0 else [] + unmatched = [y for y in unselected if y not in matched] + sorted_styles = matched + selected + unmatched + return gr.CheckboxGroup.update(choices=sorted_styles) diff --git a/webui.py b/webui.py index 845f706..845c8ad 100644 --- a/webui.py +++ b/webui.py @@ -11,7 +11,9 @@ import modules.constants as constants import modules.flags as flags import modules.gradio_hijack as grh import modules.advanced_parameters as advanced_parameters +import modules.style_sorter as style_sorter import args_manager +import copy from modules.sdxl_styles import legal_style_names from modules.private_logger import get_current_html_path @@ -193,7 +195,8 @@ with shared.gradio_root: choices=['Speed', 'Quality', 'Extreme Speed'], value='Speed') aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios, - value=modules.config.default_aspect_ratio, info='width × height') + value=modules.config.default_aspect_ratio, info='width × height', + elem_classes='aspect_ratios') image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.config.default_image_number) negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.", info='Describing what you do not want to see.', lines=2, @@ -222,10 +225,34 @@ with shared.gradio_root: gr.HTML(f'\U0001F4DA History Log') with gr.Tab(label='Style'): + initial_style_sorting = style_sorter.try_load_sorted_styles( + style_names=legal_style_names, default_selected=modules.config.default_styles) + + style_search_bar = gr.Textbox(show_label=False, container=False, + placeholder="\U0001F50E Type here to search styles ...", + value="", + label='Search Styles') style_selections = gr.CheckboxGroup(show_label=False, container=False, - choices=legal_style_names, - value=modules.config.default_styles, - label='Image Style') + choices=initial_style_sorting, + value=copy.deepcopy(modules.config.default_styles), + label='Selected Styles', + elem_classes=['style_selections']) + gradio_receiver_style_selections = gr.Textbox(elem_id='gradio_receiver_style_selections', visible=False) + + style_search_bar.change(style_sorter.search_styles, + inputs=[style_selections, style_search_bar], + outputs=style_selections, + queue=False, + show_progress=False).then( + lambda: None, _js='()=>{refresh_style_localization();}') + + gradio_receiver_style_selections.input(style_sorter.sort_styles, + inputs=style_selections, + outputs=style_selections, + queue=False, + show_progress=False).then( + lambda: None, _js='()=>{refresh_style_localization();}') + with gr.Tab(label='Model'): with gr.Row(): base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)