get each image during generating

get each image during generating
This commit is contained in:
lllyasviel 2023-10-26 17:17:05 -07:00
parent f6eee62520
commit 4607316c2f
4 changed files with 44 additions and 14 deletions

View File

@ -1 +1 @@
version = '2.1.749' version = '2.1.750'

View File

@ -3,10 +3,11 @@ import threading
buffer = [] buffer = []
outputs = [] outputs = []
global_results = []
def worker(): def worker():
global buffer, outputs global buffer, outputs, global_results
import traceback import traceback
import numpy as np import numpy as np
@ -47,6 +48,20 @@ def worker():
print(f'[Fooocus] {text}') print(f'[Fooocus] {text}')
outputs.append(['preview', (number, text, None)]) outputs.append(['preview', (number, text, None)])
def yield_result(imgs, do_not_show_finished_images=False):
global global_results
if not isinstance(imgs, list):
imgs = [imgs]
global_results = global_results + imgs
if do_not_show_finished_images:
return
outputs.append(['results', global_results])
return
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def handler(args): def handler(args):
@ -356,7 +371,7 @@ def worker():
if direct_return: if direct_return:
d = [('Upscale (Fast)', '2x')] d = [('Upscale (Fast)', '2x')]
log(uov_input_image, d, single_line_number=1) log(uov_input_image, d, single_line_number=1)
outputs.append(['results', [uov_input_image]]) yield_result(uov_input_image, do_not_show_finished_images=True)
return return
tiled = True tiled = True
@ -408,7 +423,7 @@ def worker():
pipeline.final_unet.model.diffusion_model.in_inpaint = True pipeline.final_unet.model.diffusion_model.in_inpaint = True
if advanced_parameters.debugging_cn_preprocessor: if advanced_parameters.debugging_cn_preprocessor:
outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()]) yield_result(inpaint_worker.current_task.visualize_mask_processing(), do_not_show_finished_images=True)
return return
progressbar(13, 'VAE Inpaint encoding ...') progressbar(13, 'VAE Inpaint encoding ...')
@ -454,7 +469,7 @@ def worker():
cn_img = HWC3(cn_img) cn_img = HWC3(cn_img)
task[0] = core.numpy_to_pytorch(cn_img) task[0] = core.numpy_to_pytorch(cn_img)
if advanced_parameters.debugging_cn_preprocessor: if advanced_parameters.debugging_cn_preprocessor:
outputs.append(['results', [cn_img]]) yield_result(cn_img, do_not_show_finished_images=True)
return return
for task in cn_tasks[flags.cn_cpds]: for task in cn_tasks[flags.cn_cpds]:
cn_img, cn_stop, cn_weight = task cn_img, cn_stop, cn_weight = task
@ -463,7 +478,7 @@ def worker():
cn_img = HWC3(cn_img) cn_img = HWC3(cn_img)
task[0] = core.numpy_to_pytorch(cn_img) task[0] = core.numpy_to_pytorch(cn_img)
if advanced_parameters.debugging_cn_preprocessor: if advanced_parameters.debugging_cn_preprocessor:
outputs.append(['results', [cn_img]]) yield_result(cn_img, do_not_show_finished_images=True)
return return
for task in cn_tasks[flags.cn_ip]: for task in cn_tasks[flags.cn_ip]:
cn_img, cn_stop, cn_weight = task cn_img, cn_stop, cn_weight = task
@ -474,7 +489,7 @@ def worker():
task[0] = ip_adapter.preprocess(cn_img) task[0] = ip_adapter.preprocess(cn_img)
if advanced_parameters.debugging_cn_preprocessor: if advanced_parameters.debugging_cn_preprocessor:
outputs.append(['results', [cn_img]]) yield_result(cn_img, do_not_show_finished_images=True)
return return
if len(cn_tasks[flags.cn_ip]) > 0: if len(cn_tasks[flags.cn_ip]) > 0:
@ -490,7 +505,6 @@ def worker():
advanced_parameters.freeu_s2 advanced_parameters.freeu_s2
) )
results = []
all_steps = steps * image_number all_steps = steps * image_number
preparation_time = time.perf_counter() - execution_start_time preparation_time = time.perf_counter() - execution_start_time
@ -566,7 +580,7 @@ def worker():
d.append((f'LoRA [{n}] weight', w)) d.append((f'LoRA [{n}] weight', w))
log(x, d, single_line_number=3) log(x, d, single_line_number=3)
results += imgs yield_result(imgs, do_not_show_finished_images=len(tasks) == 1)
except fcbh.model_management.InterruptProcessingException as e: except fcbh.model_management.InterruptProcessingException as e:
if shared.last_stop == 'skip': if shared.last_stop == 'skip':
print('User skipped') print('User skipped')
@ -578,8 +592,6 @@ def worker():
execution_time = time.perf_counter() - execution_start_time execution_time = time.perf_counter() - execution_start_time
print(f'Generating and saving time: {execution_time:.2f} seconds') print(f'Generating and saving time: {execution_time:.2f} seconds')
outputs.append(['results', results])
pipeline.prepare_text_encoder(async_call=True) pipeline.prepare_text_encoder(async_call=True)
return return
@ -591,7 +603,9 @@ def worker():
handler(task) handler(task)
except: except:
traceback.print_exc() traceback.print_exc()
outputs.append(['results', []]) if len(buffer) == 0:
outputs.append(['finish', global_results])
global_results = []
pass pass

View File

@ -1,3 +1,7 @@
# 2.1.750
* New UI: now you can get each image during generating.
# 2.1.743 # 2.1.743
* Improved GPT2 by removing some tokens that may corrupt styles. * Improved GPT2 by removing some tokens that may corrupt styles.

View File

@ -22,10 +22,13 @@ from modules.auth import auth_enabled, check_auth
def generate_clicked(*args): def generate_clicked(*args):
# outputs=[progress_html, progress_window, progress_gallery, gallery]
execution_start_time = time.perf_counter() execution_start_time = time.perf_counter()
yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Initializing ...')), \ yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Initializing ...')), \
gr.update(visible=True, value=None), \ gr.update(visible=True, value=None), \
gr.update(visible=False, value=None), \
gr.update(visible=False) gr.update(visible=False)
worker.buffer.append(list(args)) worker.buffer.append(list(args))
@ -39,9 +42,16 @@ def generate_clicked(*args):
percentage, title, image = product percentage, title, image = product
yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \ yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
gr.update(visible=True, value=image) if image is not None else gr.update(), \ gr.update(visible=True, value=image) if image is not None else gr.update(), \
gr.update(), \
gr.update(visible=False) gr.update(visible=False)
if flag == 'results': if flag == 'results':
yield gr.update(visible=True), \
gr.update(visible=True), \
gr.update(visible=True, value=product), \
gr.update(visible=False)
if flag == 'finish':
yield gr.update(visible=False), \ yield gr.update(visible=False), \
gr.update(visible=False), \
gr.update(visible=False), \ gr.update(visible=False), \
gr.update(visible=True, value=product) gr.update(visible=True, value=product)
finished = True finished = True
@ -60,7 +70,9 @@ shared.gradio_root = gr.Blocks(
with shared.gradio_root: with shared.gradio_root:
with gr.Row(): with gr.Row():
with gr.Column(scale=2): with gr.Column(scale=2):
progress_window = grh.Image(label='Preview', show_label=True, height=640, visible=False) with gr.Row():
progress_window = grh.Image(label='Preview', show_label=True, height=640, visible=False)
progress_gallery = gr.Gallery(label='Finished Images', show_label=True, object_fit='contain', height=640, visible=False)
progress_html = gr.HTML(value=modules.html.make_progress_html(32, 'Progress 32%'), visible=False, elem_id='progress-bar', elem_classes='progress-bar') progress_html = gr.HTML(value=modules.html.make_progress_html(32, 'Progress 32%'), visible=False, elem_id='progress-bar', elem_classes='progress-bar')
gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain', height=745, visible=True, elem_classes='resizable_area') gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain', height=745, visible=True, elem_classes='resizable_area')
with gr.Row(elem_classes='type_row'): with gr.Row(elem_classes='type_row'):
@ -356,7 +368,7 @@ with shared.gradio_root:
generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \ generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \
.then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \ .then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \
.then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \ .then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, gallery]) \ .then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
.then(lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)), outputs=[generate_button, stop_button, skip_button]) \ .then(lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)), outputs=[generate_button, stop_button, skip_button]) \
.then(fn=None, _js='playNotification') .then(fn=None, _js='playNotification')