From 09e0d1cb3ae5a1d74443009a41da9f96c1b54683 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 2 Sep 2023 06:00:20 -0700 Subject: [PATCH] 1.0.45 (#313) * Reworked SAG, removed unnecessary patch * Reworked anisotropic filters for faster compute. * Replaced with guided anisotropic filter for less distribution. --- fooocus_version.py | 2 +- modules/anisotropic.py | 15 ++ modules/default_pipeline.py | 10 + modules/patch.py | 369 +++--------------------------------- update_log.md | 8 + webui.py | 2 +- 6 files changed, 61 insertions(+), 345 deletions(-) diff --git a/fooocus_version.py b/fooocus_version.py index 0ad2642..3261afa 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '1.0.43' +version = '1.0.45' diff --git a/modules/anisotropic.py b/modules/anisotropic.py index ba16a9d..5768222 100644 --- a/modules/anisotropic.py +++ b/modules/anisotropic.py @@ -126,6 +126,21 @@ def bilateral_blur( return _bilateral_blur(input, None, kernel_size, sigma_color, sigma_space, border_type, color_distance_type) +def adaptive_anisotropic_filter(x, g=None): + if g is None: + g = x + s, m = torch.std_mean(g, dim=(1, 2, 3), keepdim=True) + s = s + 1e-5 + guidance = (g - m) / s + y = _bilateral_blur(x, guidance, + kernel_size=(13, 13), + sigma_color=3.0, + sigma_space=3.0, + border_type='reflect', + color_distance_type='l1') + return y + + def joint_bilateral_blur( input: Tensor, guidance: Tensor, diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 7db75ca..69684ec 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -4,6 +4,7 @@ import torch import modules.path from comfy.model_base import SDXL, SDXLRefiner +from modules.patch import cfg_patched xl_base: core.StableDiffusionModel = None @@ -123,6 +124,15 @@ def process(positive_prompt, negative_prompt, steps, switch, width, height, imag global positive_conditions_cache, negative_conditions_cache, \ positive_conditions_refiner_cache, negative_conditions_refiner_cache + if xl_base is not None: + xl_base.unet.model_options['sampler_cfg_function'] = cfg_patched + + if xl_base_patched is not None: + xl_base_patched.unet.model_options['sampler_cfg_function'] = cfg_patched + + if xl_refiner is not None: + xl_refiner.unet.model_options['sampler_cfg_function'] = cfg_patched + positive_conditions = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=positive_prompt) if positive_conditions_cache is None else positive_conditions_cache negative_conditions = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=negative_prompt) if negative_conditions_cache is None else negative_conditions_cache diff --git a/modules/patch.py b/modules/patch.py index 7195b2c..d573958 100644 --- a/modules/patch.py +++ b/modules/patch.py @@ -2,359 +2,43 @@ import torch import comfy.model_base import comfy.ldm.modules.diffusionmodules.openaimodel import comfy.samplers +import comfy.k_diffusion.external import modules.anisotropic as anisotropic -from comfy.samplers import model_management, lcm, math -from comfy.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed +from comfy.k_diffusion import utils sharpness = 2.0 - -def sampling_function_patched(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, - seed=None): - def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): - area = (x_in.shape[2], x_in.shape[3], 0, 0) - strength = 1.0 - if 'timestep_start' in cond[1]: - timestep_start = cond[1]['timestep_start'] - if timestep_in[0] > timestep_start: - return None - if 'timestep_end' in cond[1]: - timestep_end = cond[1]['timestep_end'] - if timestep_in[0] < timestep_end: - return None - if 'area' in cond[1]: - area = cond[1]['area'] - if 'strength' in cond[1]: - strength = cond[1]['strength'] - - adm_cond = None - if 'adm_encoded' in cond[1]: - adm_cond = cond[1]['adm_encoded'] - - input_x = x_in[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]] - if 'mask' in cond[1]: - # Scale the mask to the size of the input - # The mask should have been resized as we began the sampling process - mask_strength = 1.0 - if "mask_strength" in cond[1]: - mask_strength = cond[1]["mask_strength"] - mask = cond[1]['mask'] - assert (mask.shape[1] == x_in.shape[2]) - assert (mask.shape[2] == x_in.shape[3]) - mask = mask[:, area[2]:area[0] + area[2], area[3]:area[1] + area[3]] * mask_strength - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) - else: - mask = torch.ones_like(input_x) - mult = mask * strength - - if 'mask' not in cond[1]: - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:, :, t:1 + t, :] *= ((1.0 / rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:, :, area[0] - 1 - t:area[0] - t, :] *= ((1.0 / rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:, :, :, t:1 + t] *= ((1.0 / rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:, :, :, area[1] - 1 - t:area[1] - t] *= ((1.0 / rr) * (t + 1)) - - conditionning = {} - conditionning['c_crossattn'] = cond[0] - if cond_concat_in is not None and len(cond_concat_in) > 0: - cropped = [] - for x in cond_concat_in: - cr = x[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]] - cropped.append(cr) - conditionning['c_concat'] = torch.cat(cropped, dim=1) - - if adm_cond is not None: - conditionning['c_adm'] = adm_cond - - control = None - if 'control' in cond[1]: - control = cond[1]['control'] - - patches = None - if 'gligen' in cond[1]: - gligen = cond[1]['gligen'] - patches = {} - gligen_type = gligen[0] - gligen_model = gligen[1] - if gligen_type == "position": - gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device) - else: - gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device) - - patches['middle_patch'] = [gligen_patch] - - return (input_x, mult, conditionning, area, control, patches) - - def cond_equal_size(c1, c2): - if c1 is c2: - return True - if c1.keys() != c2.keys(): - return False - if 'c_crossattn' in c1: - s1 = c1['c_crossattn'].shape - s2 = c2['c_crossattn'].shape - if s1 != s2: - if s1[0] != s2[0] or s1[2] != s2[2]: # these 2 cases should not happen - return False - - mult_min = lcm(s1[1], s2[1]) - diff = mult_min // min(s1[1], s2[1]) - if diff > 4: # arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much - return False - if 'c_concat' in c1: - if c1['c_concat'].shape != c2['c_concat'].shape: - return False - if 'c_adm' in c1: - if c1['c_adm'].shape != c2['c_adm'].shape: - return False - return True - - def can_concat_cond(c1, c2): - if c1[0].shape != c2[0].shape: - return False - - # control - if (c1[4] is None) != (c2[4] is None): - return False - if c1[4] is not None: - if c1[4] is not c2[4]: - return False - - # patches - if (c1[5] is None) != (c2[5] is None): - return False - if (c1[5] is not None): - if c1[5] is not c2[5]: - return False - - return cond_equal_size(c1[2], c2[2]) - - def cond_cat(c_list): - c_crossattn = [] - c_concat = [] - c_adm = [] - crossattn_max_len = 0 - for x in c_list: - if 'c_crossattn' in x: - c = x['c_crossattn'] - if crossattn_max_len == 0: - crossattn_max_len = c.shape[1] - else: - crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) - c_crossattn.append(c) - if 'c_concat' in x: - c_concat.append(x['c_concat']) - if 'c_adm' in x: - c_adm.append(x['c_adm']) - out = {} - c_crossattn_out = [] - for c in c_crossattn: - if c.shape[1] < crossattn_max_len: - c = c.repeat(1, crossattn_max_len // c.shape[1], 1) # padding with repeat doesn't change result - c_crossattn_out.append(c) - - if len(c_crossattn_out) > 0: - out['c_crossattn'] = [torch.cat(c_crossattn_out)] - if len(c_concat) > 0: - out['c_concat'] = [torch.cat(c_concat)] - if len(c_adm) > 0: - out['c_adm'] = torch.cat(c_adm) - return out - - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, - model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) / 100000.0 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) / 100000.0 - - COND = 0 - UNCOND = 1 - - to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, cond_concat_in, timestep) - if p is None: - continue - - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, cond_concat_in, timestep) - if p is None: - continue - - to_run += [(p, UNCOND)] - - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - to_batch_temp = [] - for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] - - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] - - for i in range(1, len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp) // i] - if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area): - to_batch = batch_amount - break - - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x += [p[0]] - mult += [p[1]] - c += [p[2]] - area += [p[3]] - cond_or_uncond += [o[1]] - control = p[4] - patches = p[5] - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - - transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() - - if patches is not None: - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - else: - transformer_options["patches"] = patches - - c['transformer_options'] = transformer_options - - transformer_options['uc_mask'] = torch.Tensor(cond_or_uncond).to(input_x).float()[:, None, None, None] - - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model_function, - {"input": input_x, "timestep": timestep_, "c": c, - "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) - else: - output = model_function(input_x, timestep_, **c).chunk(batch_chunks) - del input_x - - model_management.throw_exception_if_processing_interrupted() - - for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += output[ - o] * \ - mult[o] - out_count[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += mult[o] - else: - out_uncond[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += output[ - o] * \ - mult[o] - out_uncond_count[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += \ - mult[o] - del mult - - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - - return out_cond, out_uncond - - max_total_area = model_management.maximum_batch_area() - if math.isclose(cond_scale, 1.0): - uncond = None - - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, - model_options) - if "sampler_cfg_function" in model_options: - args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep} - return model_options["sampler_cfg_function"](args) - else: - return uncond + (cond - uncond) * cond_scale +cfg_x0 = 0.0 +cfg_s = 1.0 -def unet_forward_patched(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): - uc_mask = transformer_options['uc_mask'] - transformer_options["original_shape"] = list(x.shape) - transformer_options["current_index"] = 0 +def cfg_patched(args): + global cfg_x0, cfg_s + positive_eps = args['cond'].clone() + positive_x0 = args['cond'] * cfg_s + cfg_x0 + uncond = args['uncond'] * cfg_s + cfg_x0 + cond_scale = args['cond_scale'] + t = args['timestep'] - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x.type(self.dtype) - for id, module in enumerate(self.input_blocks): - transformer_options["block"] = ("input", id) - h = forward_timestep_embed(module, h, emb, context, transformer_options) - if control is not None and 'input' in control and len(control['input']) > 0: - ctrl = control['input'].pop() - if ctrl is not None: - h += ctrl - hs.append(h) - transformer_options["block"] = ("middle", 0) - h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) - if control is not None and 'middle' in control and len(control['middle']) > 0: - h += control['middle'].pop() - - for id, module in enumerate(self.output_blocks): - transformer_options["block"] = ("output", id) - hsp = hs.pop() - if control is not None and 'output' in control and len(control['output']) > 0: - ctrl = control['output'].pop() - if ctrl is not None: - hsp += ctrl - - h = torch.cat([h, hsp], dim=1) - del hsp - if len(hs) > 0: - output_shape = hs[-1].shape - else: - output_shape = None - h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) - h = h.type(x.dtype) - x0 = self.out(h) - - alpha = 1.0 - (timesteps / 999.0)[:, None, None, None].clone() + alpha = 1.0 - (t / 999.0)[:, None, None, None].clone() alpha *= 0.001 * sharpness - degraded_x0 = anisotropic.bilateral_blur(x0) * alpha + x0 * (1.0 - alpha) - x0 = x0 * uc_mask + degraded_x0 * (1.0 - uc_mask) + eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0) + eps_degraded_weighted = eps_degraded * alpha + positive_eps * (1.0 - alpha) - return x0 + cond = eps_degraded_weighted * cfg_s + cfg_x0 + + return uncond + (cond - uncond) * cond_scale + + +def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs): + global cfg_x0, cfg_s + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + cfg_x0 = input + cfg_s = c_out + return self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) def sdxl_encode_adm_patched(self, **kwargs): @@ -385,6 +69,5 @@ def sdxl_encode_adm_patched(self, **kwargs): def patch_all(): - comfy.samplers.sampling_function = sampling_function_patched + comfy.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched - comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = unet_forward_patched diff --git a/update_log.md b/update_log.md index 50930fb..ec90ce3 100644 --- a/update_log.md +++ b/update_log.md @@ -1,3 +1,11 @@ +### 1.0.45 + +* Reworked SAG, removed unnecessary patch +* Reworked anisotropic filters for faster compute. +* Replaced with guided anisotropic filter for less distortion. + +### 1.0.41 + (The update of Fooocus will be paused for a period of time for AUTOMATIC1111 sd-webui 1.6.X, and some features will also be implemented as webui extensions) ### 1.0.40 diff --git a/webui.py b/webui.py index a2a5c09..550b5ab 100644 --- a/webui.py +++ b/webui.py @@ -91,7 +91,7 @@ with shared.gradio_root: with gr.Row(): model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button') with gr.Accordion(label='Advanced', open=False): - sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=40.0, step=0.01, value=2.0) + sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=30.0, step=0.01, value=2.0) gr.HTML('\U0001F4D4 Document') def model_refresh_clicked():