[Fooocus 2.0.50] Variation/Upscale (Midjourney Toolbar) (#389)

Author: lllyasviel, 2023-09-16 03:29:41 -07:00 (committed by GitHub)
parent 58c29aed00
commit 8ef31d33af
15 changed files with 446 additions and 100 deletions

.gitignore (vendored, 1 change)

@@ -6,7 +6,6 @@ __pycache__
lena.png
lena_result.png
lena_test.py
-!taesdxl_decoder.pth
/repositories
/venv
/tmp

fooocus_version.py

@@ -1 +1 @@
-version = '2.0.19'
+version = '2.0.50'

launch.py

@@ -9,7 +9,7 @@ import fooocus_version
from modules.launch_util import is_installed, run, python, \
    run_pip, repo_dir, git_clone, requirements_met, script_path, dir_repos
from modules.model_loader import load_file_from_url
-from modules.path import modelfile_path, lorafile_path, vae_approx_path, fooocus_expansion_path
+from modules.path import modelfile_path, lorafile_path, vae_approx_path, fooocus_expansion_path, upscale_models_path

REINSTALL_ALL = False

@@ -67,8 +67,14 @@ lora_filenames = [
]

vae_approx_filenames = [
-    ('taesdxl_decoder.pth',
-     'https://huggingface.co/lllyasviel/misc/resolve/main/taesdxl_decoder.pth')
+    ('xlvaeapp.pth',
+     'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth')
+]
+
+upscaler_filenames = [
+    ('fooocus_upscaler_s409985e5.bin',
+     'https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin')
]

@@ -79,6 +85,8 @@ def download_models():
        load_file_from_url(url=url, model_dir=lorafile_path, file_name=file_name)
    for file_name, url in vae_approx_filenames:
        load_file_from_url(url=url, model_dir=vae_approx_path, file_name=file_name)
+    for file_name, url in upscaler_filenames:
+        load_file_from_url(url=url, model_dir=upscale_models_path, file_name=file_name)
    load_file_from_url(
        url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin',

modules/async_worker.py

@@ -1,7 +1,6 @@
import threading
import torch

buffer = []
outputs = []

@@ -14,14 +13,18 @@ def worker():
    import random
    import copy
    import modules.default_pipeline as pipeline
+    import modules.core as core
+    import modules.flags as flags
    import modules.path
    import modules.patch
    import modules.virtual_memory as virtual_memory
+    import comfy.model_management
    from modules.sdxl_styles import apply_style, aspect_ratios, fooocus_expansion
    from modules.private_logger import log
    from modules.expansion import safe_str
-    from modules.util import join_prompts, remove_empty_str
+    from modules.util import join_prompts, remove_empty_str, HWC3, resize_image
+    from modules.upscaler import perform_upscale

    try:
        async_gradio_app = shared.gradio_root

@@ -37,16 +40,21 @@ def worker():
        outputs.append(['preview', (number, text, None)])

    @torch.no_grad()
+    @torch.inference_mode()
    def handler(task):
        prompt, negative_prompt, style_selections, performance_selction, \
            aspect_ratios_selction, image_number, image_seed, sharpness, \
            base_model_name, refiner_model_name, \
-            l1, w1, l2, w2, l3, w3, l4, w4, l5, w5 = task
+            l1, w1, l2, w2, l3, w3, l4, w4, l5, w5, \
+            input_image_checkbox, \
+            uov_method, uov_input_image = task

        loras = [(l1, w1), (l2, w2), (l3, w3), (l4, w4), (l5, w5)]
        raw_style_selections = copy.deepcopy(style_selections)
+        uov_method = uov_method.lower()

        if fooocus_expansion in style_selections:
            use_expansion = True
            style_selections.remove(fooocus_expansion)

@@ -54,8 +62,80 @@ def worker():
            use_expansion = False

        use_style = len(style_selections) > 0

        modules.patch.sharpness = sharpness

+        initial_latent = None
+        denoising_strength = 1.0
+        tiled = False
+
+        if performance_selction == 'Speed':
+            steps = 30
+            switch = 20
+        else:
+            steps = 60
+            switch = 40
+
+        pipeline.clear_all_caches()  # save memory
+        width, height = aspect_ratios[aspect_ratios_selction]
+
+        if input_image_checkbox:
+            progressbar(0, 'Image processing ...')
+            if uov_method != flags.disabled and uov_input_image is not None:
+                uov_input_image = HWC3(uov_input_image)
+                H, W, C = uov_input_image.shape
+                if 'vary' in uov_method:
+                    if H * W + 8 < width * height or float(abs(H * width - W * height)) > 1.5 * float(max(H, W, width, height)):
+                        uov_input_image = resize_image(uov_input_image, width=width, height=height)
+                        print(f'Aspect ratio corrected - users are uploading their own images.')
+                    if 'subtle' in uov_method:
+                        denoising_strength = 0.5
+                    if 'strong' in uov_method:
+                        denoising_strength = 0.85
+                    initial_pixels = core.numpy_to_pytorch(uov_input_image)
+                    progressbar(0, 'VAE encoding ...')
+                    initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=initial_pixels)
+                    B, C, H, W = initial_latent['samples'].shape
+                    width = W * 8
+                    height = H * 8
+                    print(f'Final resolution is {str((height, width))}.')
+                elif 'upscale' in uov_method:
+                    if '1.5x' in uov_method:
+                        f = 1.5
+                    elif '2x' in uov_method:
+                        f = 2.0
+                    else:
+                        f = 1.0
+                    width = int(W * f)
+                    height = int(H * f)
+                    image_is_super_large = width * height > 2800 * 2800
+                    progressbar(0, f'Upscaling image from {str((H, W))} to {str((height, width))}...')
+                    uov_input_image = core.numpy_to_pytorch(uov_input_image)
+                    uov_input_image = perform_upscale(uov_input_image)
+                    uov_input_image = core.pytorch_to_numpy(uov_input_image)[0]
+                    uov_input_image = resize_image(uov_input_image, width=width, height=height)
+                    print(f'Image upscaled.')
+                    if 'fast' in uov_method or image_is_super_large:
+                        if 'fast' not in uov_method:
+                            print('Image is too large. Directly returned the SR image. '
+                                  'Usually directly return SR image at 4K resolution '
+                                  'yields better results than SDXL diffusion.')
+                        outputs.append(['results', [uov_input_image]])
+                        return
+                    tiled = True
+                    denoising_strength = 1.0 - 0.618
+                    steps = int(steps * 0.618)
+                    switch = int(steps * 0.67)
+                    initial_pixels = core.numpy_to_pytorch(uov_input_image)
+                    progressbar(0, 'VAE encoding ...')
+                    initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=initial_pixels, tiled=True)
+                    B, C, H, W = initial_latent['samples'].shape
+                    width = W * 8
+                    height = H * 8
+                    print(f'Final resolution is {str((height, width))}.')
+
        progressbar(1, 'Initializing ...')

@@ -152,16 +232,6 @@ def worker():
            virtual_memory.try_move_to_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model)

-        if performance_selction == 'Speed':
-            steps = 30
-            switch = 20
-        else:
-            steps = 60
-            switch = 40
-
-        pipeline.clear_all_caches()  # save memory
-        width, height = aspect_ratios[aspect_ratios_selction]
-
        results = []
        all_steps = steps * image_number

@@ -174,35 +244,43 @@ def worker():
        outputs.append(['preview', (13, 'Starting tasks ...', None)])
        for current_task_id, task in enumerate(tasks):
+            try:
                imgs = pipeline.process_diffusion(
                    positive_cond=task['c'],
                    negative_cond=task['uc'],
                    steps=steps,
                    switch=switch,
                    width=width,
                    height=height,
                    image_seed=task['task_seed'],
-                    callback=callback)
+                    callback=callback,
+                    latent=initial_latent,
+                    denoise=denoising_strength,
+                    tiled=tiled
+                )

                for x in imgs:
                    d = [
                        ('Prompt', raw_prompt),
                        ('Negative Prompt', raw_negative_prompt),
                        ('Fooocus V2 Expansion', task['expansion']),
                        ('Styles', str(raw_style_selections)),
                        ('Performance', performance_selction),
                        ('Resolution', str((width, height))),
                        ('Sharpness', sharpness),
                        ('Base Model', base_model_name),
                        ('Refiner Model', refiner_model_name),
                        ('Seed', task['task_seed'])
                    ]
                    for n, w in loras:
                        if n != 'None':
                            d.append((f'LoRA [{n}] weight', w))
                    log(x, d, single_line_number=3)

                results += imgs
+            except comfy.model_management.InterruptProcessingException as e:
+                print('User stopped')
+                break

        outputs.append(['results', results])
        return
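For reference, the 0.618 constants in the upscale branch above shrink both the step count and the denoise strength; a quick sketch of what they evaluate to under the default 'Speed' preset (illustration only, not part of the commit):

# Quick check of the upscale-branch constants, assuming the 'Speed' preset (steps=30).
steps = 30
denoising_strength = 1.0 - 0.618    # ~0.382: keep most of the upscaled image's structure
steps = int(steps * 0.618)          # 18 diffusion steps instead of 30
switch = int(steps * 0.67)          # 12: refiner switch step
print(steps, switch, round(denoising_strength, 3))   # -> 18 12 0.382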

modules/core.py

@@ -8,7 +8,7 @@ import comfy.model_management
import comfy.utils
from comfy.sd import load_checkpoint_guess_config
-from nodes import VAEDecode, EmptyLatentImage
+from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
from comfy.model_base import SDXLRefiner
from modules.samplers_advanced import KSampler, KSamplerWithRefiner

@@ -18,6 +18,9 @@ from modules.patch import patch_all
patch_all()
opEmptyLatentImage = EmptyLatentImage()
opVAEDecode = VAEDecode()
+opVAEEncode = VAEEncode()
+opVAEDecodeTiled = VAEDecodeTiled()
+opVAEEncodeTiled = VAEEncodeTiled()

class StableDiffusionModel:

@@ -45,12 +48,14 @@ class StableDiffusionModel:

@torch.no_grad()
+@torch.inference_mode()
def load_model(ckpt_filename):
    unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename)
    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, model_filename=ckpt_filename)

@torch.no_grad()
+@torch.inference_mode()
def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
    if strength_model == 0 and strength_clip == 0:
        return model

@@ -61,40 +66,87 @@ def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):

@torch.no_grad()
+@torch.inference_mode()
def generate_empty_latent(width=1024, height=1024, batch_size=1):
    return opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)[0]

@torch.no_grad()
-def decode_vae(vae, latent_image):
-    return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
-
-
-def get_previewer(device, latent_format):
-    from latent_preview import TAESD, TAESDPreviewerImpl
-    taesd_decoder_path = os.path.abspath(os.path.realpath(os.path.join("models", "vae_approx",
-                                                                       latent_format.taesd_decoder_name)))
-    if not os.path.exists(taesd_decoder_path):
-        print(f"Warning: TAESD previews enabled, but could not find {taesd_decoder_path}")
-        return None
-
-    taesd = TAESD(None, taesd_decoder_path).to(device)
-
-    def preview_function(x0, step, total_steps):
-        global cv2_is_top
-        with torch.no_grad():
-            x_sample = taesd.decoder(torch.nn.functional.avg_pool2d(x0, kernel_size=(2, 2))).detach() * 255.0
-            x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')
-            x_sample = x_sample.cpu().numpy().clip(0, 255).astype(np.uint8)
-            return x_sample[0]
-
-    taesd.preview = preview_function
-
-    return taesd
+@torch.inference_mode()
+def decode_vae(vae, latent_image, tiled=False):
+    return (opVAEDecodeTiled if tiled else opVAEDecode).decode(samples=latent_image, vae=vae)[0]

@torch.no_grad()
+@torch.inference_mode()
+def encode_vae(vae, pixels, tiled=False):
+    return (opVAEEncodeTiled if tiled else opVAEEncode).encode(pixels=pixels, vae=vae)[0]
+
+
+class VAEApprox(torch.nn.Module):
+    def __init__(self):
+        super(VAEApprox, self).__init__()
+        self.conv1 = torch.nn.Conv2d(4, 8, (7, 7))
+        self.conv2 = torch.nn.Conv2d(8, 16, (5, 5))
+        self.conv3 = torch.nn.Conv2d(16, 32, (3, 3))
+        self.conv4 = torch.nn.Conv2d(32, 64, (3, 3))
+        self.conv5 = torch.nn.Conv2d(64, 32, (3, 3))
+        self.conv6 = torch.nn.Conv2d(32, 16, (3, 3))
+        self.conv7 = torch.nn.Conv2d(16, 8, (3, 3))
+        self.conv8 = torch.nn.Conv2d(8, 3, (3, 3))
+        self.current_type = None
+
+    def forward(self, x):
+        extra = 11
+        x = torch.nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+        x = torch.nn.functional.pad(x, (extra, extra, extra, extra))
+        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8]:
+            x = layer(x)
+            x = torch.nn.functional.leaky_relu(x, 0.1)
+        return x
+
+
+VAE_approx_model = None
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def get_previewer(device, latent_format):
+    global VAE_approx_model
+
+    if VAE_approx_model is None:
+        from modules.path import vae_approx_path
+        vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth')
+        sd = torch.load(vae_approx_filename, map_location='cpu')
+        VAE_approx_model = VAEApprox()
+        VAE_approx_model.load_state_dict(sd)
+        del sd
+        VAE_approx_model.eval()
+
+        if comfy.model_management.should_use_fp16():
+            VAE_approx_model.half()
+            VAE_approx_model.current_type = torch.float16
+        else:
+            VAE_approx_model.float()
+            VAE_approx_model.current_type = torch.float32
+
+        VAE_approx_model.to(comfy.model_management.get_torch_device())
+
+    @torch.no_grad()
+    @torch.inference_mode()
+    def preview_function(x0, step, total_steps):
+        with torch.no_grad():
+            x_sample = x0.to(VAE_approx_model.current_type)
+            x_sample = VAE_approx_model(x_sample) * 127.5 + 127.5
+            x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')[0]
+            x_sample = x_sample.cpu().numpy().clip(0, 255).astype(np.uint8)
+            return x_sample
+
+    return preview_function

@torch.no_grad()
+@torch.inference_mode()
def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
             scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
             force_full_denoise=False, callback_function=None):

@@ -124,8 +176,8 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
    def callback(step, x0, x, total_steps):
        y = None
-        if previewer and step % 3 == 0:
-            y = previewer.preview(x0, step, total_steps)
+        if previewer is not None:
+            y = previewer(x0, step, total_steps)
        if callback_function is not None:
            callback_function(step, x0, x, total_steps, y)
        pbar.update_absolute(step + 1, total_steps, None)

@@ -166,6 +218,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa

@torch.no_grad()
+@torch.inference_mode()
def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative, latent,
                          seed=None, steps=30, refiner_switch_step=20, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
                          scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,

@@ -196,8 +249,8 @@ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive,
    def callback(step, x0, x, total_steps):
        y = None
-        if previewer and step % 3 == 0:
-            y = previewer.preview(x0, step, total_steps)
+        if previewer is not None:
+            y = previewer(x0, step, total_steps)
        if callback_function is not None:
            callback_function(step, x0, x, total_steps, y)
        pbar.update_absolute(step + 1, total_steps, None)

@@ -243,5 +296,16 @@ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive,

@torch.no_grad()
-def image_to_numpy(x):
+@torch.inference_mode()
+def pytorch_to_numpy(x):
    return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x]
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def numpy_to_pytorch(x):
+    y = x.astype(np.float32) / 255.0
+    y = y[None]
+    y = np.ascontiguousarray(y.copy())
+    y = torch.from_numpy(y).float()
+    return y
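The previewer above replaces the TAESD decoder with the small VAEApprox network loaded from xlvaeapp.pth; after the 2x interpolation, the pad(11) adds back exactly the 22 pixels that the eight valid convolutions remove, so a (B, 4, H, W) latent yields a (B, 3, 2H, 2W) RGB preview. A minimal shape check using the class above (illustrative sketch, not part of the commit):

import torch
net = VAEApprox().eval()                  # randomly initialized, shapes only
with torch.no_grad():
    y = net(torch.zeros(1, 4, 128, 128))  # latent of a 1024x1024 SDXL image
print(y.shape)                            # torch.Size([1, 3, 256, 256]) -> 256x256 RGB preview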

modules/default_pipeline.py

@@ -3,7 +3,6 @@ import os
import torch
import modules.path
import modules.virtual_memory as virtual_memory
-import comfy.model_management as model_management
from comfy.model_base import SDXL, SDXLRefiner
from modules.patch import cfg_patched

@@ -20,6 +19,8 @@ xl_base_patched: core.StableDiffusionModel = None
xl_base_patched_hash = ''

+@torch.no_grad()
+@torch.inference_mode()
def refresh_base_model(name):
    global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash

@@ -51,6 +52,8 @@ def refresh_base_model(name):
    return

+@torch.no_grad()
+@torch.inference_mode()
def refresh_refiner_model(name):
    global xl_refiner, xl_refiner_hash

@@ -86,6 +89,8 @@ def refresh_refiner_model(name):
    return

+@torch.no_grad()
+@torch.inference_mode()
def refresh_loras(loras):
    global xl_base, xl_base_patched, xl_base_patched_hash
    if xl_base_patched_hash == str(loras):

@@ -106,6 +111,7 @@ def refresh_loras(loras):

@torch.no_grad()
+@torch.inference_mode()
def clip_encode_single(clip, text, verbose=False):
    cached = clip.fcs_cond_cache.get(text, None)
    if cached is not None:

@@ -121,6 +127,7 @@ def clip_encode_single(clip, text, verbose=False):

@torch.no_grad()
+@torch.inference_mode()
def clip_encode(sd, texts, pool_top_k=1):
    if sd is None:
        return None

@@ -145,6 +152,7 @@ def clip_encode(sd, texts, pool_top_k=1):

@torch.no_grad()
+@torch.inference_mode()
def clear_sd_cond_cache(sd):
    if sd is None:
        return None

@@ -155,11 +163,14 @@ def clear_sd_cond_cache(sd):

@torch.no_grad()
+@torch.inference_mode()
def clear_all_caches():
    clear_sd_cond_cache(xl_base_patched)
    clear_sd_cond_cache(xl_refiner)

+@torch.no_grad()
+@torch.inference_mode()
def refresh_everything(refiner_model_name, base_model_name, loras):
    refresh_refiner_model(refiner_model_name)
    if xl_refiner is not None:

@@ -184,6 +195,7 @@ expansion = FooocusExpansion()

@torch.no_grad()
+@torch.inference_mode()
def patch_all_models():
    assert xl_base is not None
    assert xl_base_patched is not None

@@ -198,14 +210,18 @@ def patch_all_models():

@torch.no_grad()
-def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback):
+@torch.inference_mode()
+def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback, latent=None, denoise=1.0, tiled=False):
    patch_all_models()

    if xl_refiner is not None:
        virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
    virtual_memory.load_from_virtual_memory(xl_base.unet.model)

-    empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
+    if latent is None:
+        empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
+    else:
+        empty_latent = latent

    if xl_refiner is not None:
        sampled_latent = core.ksampler_with_refiner(

@@ -219,6 +235,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
            latent=empty_latent,
            steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
            seed=image_seed,
+            denoise=denoise,
            callback_function=callback
        )
    else:

@@ -229,9 +246,10 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
            latent=empty_latent,
            steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
            seed=image_seed,
+            denoise=denoise,
            callback_function=callback
        )

-    decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent)
+    decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent, tiled=tiled)

-    images = core.image_to_numpy(decoded_latent)
+    images = core.pytorch_to_numpy(decoded_latent)
    return images

modules/flags.py (new file, 10 lines)

@@ -0,0 +1,10 @@
+disabled = 'Disabled'
+subtle_variation = 'Vary (Subtle)'
+strong_variation = 'Vary (Strong)'
+upscale_15 = 'Upscale (1.5x)'
+upscale_2 = 'Upscale (2x)'
+upscale_fast = 'Upscale (Fast 2x)'
+
+uov_list = [
+    disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast
+]

modules/html.py

@@ -83,6 +83,14 @@ progress::after {
    box-shadow: none !important;
}

+.advanced_check_row{
+    width: 250px !important;
+}
+
+.min_check{
+    min-width: min(1px, 100%) !important;
+}
+
'''

progress_html = '''
<div class="loader-container">

modules/patch.py

@@ -70,6 +70,35 @@ def sdxl_encode_adm_patched(self, **kwargs):
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

+def sdxl_refiner_encode_adm_patched(self, **kwargs):
+    clip_pooled = kwargs["pooled_output"]
+    width = kwargs.get("width", 768)
+    height = kwargs.get("height", 768)
+    crop_w = kwargs.get("crop_w", 0)
+    crop_h = kwargs.get("crop_h", 0)
+
+    if kwargs.get("prompt_type", "") == "negative":
+        aesthetic_score = kwargs.get("aesthetic_score", 2.5)
+    else:
+        aesthetic_score = kwargs.get("aesthetic_score", 7.0)
+
+    if kwargs.get("prompt_type", "") == "negative":
+        width *= 0.8
+        height *= 0.8
+    elif kwargs.get("prompt_type", "") == "positive":
+        width *= 1.5
+        height *= 1.5
+
+    out = []
+    out.append(self.embedder(torch.Tensor([height])))
+    out.append(self.embedder(torch.Tensor([width])))
+    out.append(self.embedder(torch.Tensor([crop_h])))
+    out.append(self.embedder(torch.Tensor([crop_w])))
+    out.append(self.embedder(torch.Tensor([aesthetic_score])))
+    flat = torch.flatten(torch.cat(out))[None,]
+    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
+
+
def text_encoder_device_patched():
    # Fooocus's style system uses text encoder much more times than comfy so this makes things much faster.
    return comfy.model_management.get_torch_device()

@@ -83,3 +112,4 @@ def patch_all():
    comfy.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
    comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
+    # comfy.model_base.SDXLRefiner.encode_adm = sdxl_refiner_encode_adm_patched

modules/path.py

@@ -3,6 +3,7 @@ import os
modelfile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/checkpoints/'))
lorafile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/loras/'))
vae_approx_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/vae_approx/'))
+upscale_models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/upscale_models/'))
temp_outputs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../outputs/'))
fooocus_expansion_path = os.path.abspath(os.path.join(os.path.dirname(__file__),

modules/upscaler.py (new file, 25 lines)

@@ -0,0 +1,25 @@
+import os
+import torch
+from comfy_extras.chainner_models.architecture.RRDB import RRDBNet as ESRGAN
+from comfy_extras.nodes_upscale_model import ImageUpscaleWithModel
+from collections import OrderedDict
+from modules.path import upscale_models_path
+
+model_filename = os.path.join(upscale_models_path, 'fooocus_upscaler_s409985e5.bin')
+opImageUpscaleWithModel = ImageUpscaleWithModel()
+model = None
+
+
+def perform_upscale(img):
+    global model
+
+    if model is None:
+        sd = torch.load(model_filename)
+        sdo = OrderedDict()
+        for k, v in sd.items():
+            sdo[k.replace('residual_block_', 'RDB')] = v
+        del sd
+        model = ESRGAN(sdo)
+        model.cpu()
+        model.eval()
+    return opImageUpscaleWithModel.upscale(model, img)[0]
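A sketch of how the worker above uses this module (assuming the Fooocus modules are importable and the upscaler weights were downloaded by launch.py); images are round-tripped through the NHWC helpers in modules/core.py:

import numpy as np
import modules.core as core
from modules.upscaler import perform_upscale

img = np.zeros((512, 512, 3), dtype=np.uint8)   # any HWC uint8 RGB image
t = core.numpy_to_pytorch(img)                  # (1, 512, 512, 3) float32 in [0, 1]
t = perform_upscale(t)                          # upscaled NHWC tensor
out = core.pytorch_to_numpy(t)[0]               # back to HWC uint8; the worker then resizes to the target size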

modules/util.py

@@ -1,7 +1,90 @@
+import numpy as np
import datetime
import random
import os
+
+from PIL import Image
+
+
+LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+
+
+def resize_image(im, width, height, resize_mode=1):
+    """
+    Resizes an image with the specified resize_mode, width, and height.
+
+    Args:
+        resize_mode: The mode to use when resizing the image.
+            0: Resize the image to the specified width and height.
+            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+        im: The image to resize.
+        width: The width to resize the image to.
+        height: The height to resize the image to.
+    """
+
+    im = Image.fromarray(im)
+
+    def resize(im, w, h):
+        return im.resize((w, h), resample=LANCZOS)
+
+    if resize_mode == 0:
+        res = resize(im, width, height)
+
+    elif resize_mode == 1:
+        ratio = width / height
+        src_ratio = im.width / im.height
+
+        src_w = width if ratio > src_ratio else im.width * height // im.height
+        src_h = height if ratio <= src_ratio else im.height * width // im.width
+
+        resized = resize(im, src_w, src_h)
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+    else:
+        ratio = width / height
+        src_ratio = im.width / im.height
+
+        src_w = width if ratio < src_ratio else im.width * height // im.height
+        src_h = height if ratio >= src_ratio else im.height * width // im.width
+
+        resized = resize(im, src_w, src_h)
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+        if ratio < src_ratio:
+            fill_height = height // 2 - src_h // 2
+            if fill_height > 0:
+                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+        elif ratio > src_ratio:
+            fill_width = width // 2 - src_w // 2
+            if fill_width > 0:
+                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+
+    return np.array(res)
+
+
+def HWC3(x):
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+
+
def remove_empty_str(items, default=None):
    items = [x for x in items if x != ""]
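A short sketch (illustration only) of how the worker normalizes uploaded images with these helpers:

import numpy as np
from modules.util import HWC3, resize_image

rgba = np.zeros((600, 400, 4), dtype=np.uint8)        # e.g. a PNG with an alpha channel
rgb = HWC3(rgba)                                      # (600, 400, 3), alpha composited over white
fitted = resize_image(rgb, width=1024, height=1024)   # default mode 1: fill, center, crop the excess
print(fitted.shape)                                   # (1024, 1024, 3)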

shared.py

@@ -1,2 +1 @@
gradio_root = None

webui.py

@@ -7,13 +7,14 @@ import modules.path
import fooocus_version
import modules.html
import modules.async_worker as worker
+import modules.flags as flags
+import comfy.model_management as model_management

from modules.sdxl_styles import style_keys, aspect_ratios, fooocus_expansion, default_styles


def generate_clicked(*args):
-    yield gr.update(interactive=False), \
-        gr.update(visible=True, value=modules.html.make_progress_html(1, 'Initializing ...')), \
+    yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Initializing ...')), \
        gr.update(visible=True, value=None), \
        gr.update(visible=False)

@@ -26,13 +27,11 @@ def generate_clicked(*args):
            flag, product = worker.outputs.pop(0)
            if flag == 'preview':
                percentage, title, image = product
-                yield gr.update(interactive=False), \
-                    gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
+                yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
                    gr.update(visible=True, value=image) if image is not None else gr.update(), \
                    gr.update(visible=False)
            if flag == 'results':
-                yield gr.update(interactive=True), \
-                    gr.update(visible=False), \
+                yield gr.update(visible=False), \
                    gr.update(visible=False), \
                    gr.update(visible=True, value=product)
                finished = True
@@ -50,9 +49,28 @@ with shared.gradio_root:
                with gr.Column(scale=0.85):
                    prompt = gr.Textbox(show_label=False, placeholder="Type prompt here.", container=False, autofocus=True, elem_classes='type_row', lines=1024)
                with gr.Column(scale=0.15, min_width=0):
-                    run_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row')
-            with gr.Row():
-                advanced_checkbox = gr.Checkbox(label='Advanced', value=False, container=False)
+                    run_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', visible=True)
+                    stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row', visible=False)
+
+                    def stop_clicked():
+                        model_management.interrupt_current_processing()
+                        return gr.update(interactive=False)
+
+                    stop_button.click(stop_clicked, outputs=stop_button, queue=False)
+
+            with gr.Row(elem_classes='advanced_check_row'):
+                input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
+                advanced_checkbox = gr.Checkbox(label='Advanced', value=False, container=False, elem_classes='min_check')
+            with gr.Row(visible=False) as image_input_panel:
+                with gr.Column(scale=0.5):
+                    with gr.Accordion(label='Upscale or Variation', open=True):
+                        uov_input_image = gr.Image(label='Drag above image to here', source='upload', type='numpy')
+                        uov_method = gr.Radio(label='Method', choices=flags.uov_list, value=flags.disabled, show_label=False, container=False)
+                        gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/390">\U0001F4D4 Document</a>')
+            input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox, outputs=image_input_panel, queue=False)
+
+            # def get_select_index(g, evt: gr.SelectData):
+            #     return g[evt.index]['name']
+            # gallery.select(get_select_index, gallery, uov_input_image)
+
        with gr.Column(scale=0.5, visible=False) as right_col:
            with gr.Tab(label='Setting'):
                performance_selction = gr.Radio(label='Performance', choices=['Speed', 'Quality'], value='Speed')

@@ -73,7 +91,7 @@ with shared.gradio_root:
                    else:
                        return s

-                seed_random.change(random_checked, inputs=[seed_random], outputs=[image_seed])
+                seed_random.change(random_checked, inputs=[seed_random], outputs=[image_seed], queue=False)

            with gr.Tab(label='Style'):
                style_selections = gr.CheckboxGroup(show_label=False, container=False,

@@ -105,16 +123,21 @@ with shared.gradio_root:
                        results += [gr.update(choices=['None'] + modules.path.lora_filenames), gr.update()]
                    return results

-                model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls)
+                model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls, queue=False)

-    advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, right_col)
+    advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, right_col, queue=False)

    ctrls = [
        prompt, negative_prompt, style_selections,
        performance_selction, aspect_ratios_selction, image_number, image_seed, sharpness
    ]
    ctrls += [base_model, refiner_model] + lora_ctrls
+    ctrls += [input_image_checkbox]
+    ctrls += [uov_method, uov_input_image]

-    run_button.click(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed)\
-        .then(fn=generate_clicked, inputs=ctrls, outputs=[run_button, progress_html, progress_window, gallery])
+    run_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, run_button, gallery])\
+        .then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed)\
+        .then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, gallery])\
+        .then(lambda: (gr.update(visible=True), gr.update(visible=False)), outputs=[run_button, stop_button])

parser = argparse.ArgumentParser()
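Note that ctrls is positional: the three new controls are appended after the LoRA controls, matching the order they are unpacked from task at the top of handler() in modules/async_worker.py (..., l5, w5, input_image_checkbox, uov_method, uov_input_image = task), and the run/stop buttons are now toggled by the surrounding .then() chain rather than by generate_clicked itself.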