diff --git a/extras/ip_adapter.py b/extras/ip_adapter.py index a145b68..b18f0df 100644 --- a/extras/ip_adapter.py +++ b/extras/ip_adapter.py @@ -167,14 +167,7 @@ def preprocess(img, ip_adapter_path): ldm_patched.modules.model_management.load_model_gpu(clip_vision.patcher) pixel_values = clip_preprocess(numpy_to_pytorch(img).to(clip_vision.load_device)) - - if clip_vision.dtype != torch.float32: - precision_scope = torch.autocast - else: - precision_scope = lambda a, b: contextlib.nullcontext(a) - - with precision_scope(ldm_patched.modules.model_management.get_autocast_device(clip_vision.load_device), torch.float32): - outputs = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2) + outputs = clip_vision.model(pixel_values=pixel_values, intermediate_output=-2) ip_adapter = entry['ip_adapter'] ip_layers = entry['ip_layers'] diff --git a/fooocus_version.py b/fooocus_version.py index 1a708c5..a7dac99 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.843' +version = '2.1.844' diff --git a/ldm_patched/contrib/external.py b/ldm_patched/contrib/external.py index e20b08c..7f95f08 100644 --- a/ldm_patched/contrib/external.py +++ b/ldm_patched/contrib/external.py @@ -1870,6 +1870,7 @@ def init_custom_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_sag.py", + "nodes_perpneg.py", ] for node_file in extras_files: diff --git a/ldm_patched/contrib/external_latent.py b/ldm_patched/contrib/external_latent.py index e2364b8..c6f874e 100644 --- a/ldm_patched/contrib/external_latent.py +++ b/ldm_patched/contrib/external_latent.py @@ -5,9 +5,7 @@ import torch def reshape_latent_to(target_shape, latent): if latent.shape[1:] != target_shape[1:]: - latent.movedim(1, -1) latent = ldm_patched.modules.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") - latent.movedim(-1, 1) return ldm_patched.modules.utils.repeat_to_batch_size(latent, target_shape[0]) @@ -104,9 +102,32 @@ class LatentInterpolate: samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) return (samples_out,) +class LatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "batch" + + CATEGORY = "latent/batch" + + def batch(self, samples1, samples2): + samples_out = samples1.copy() + s1 = samples1["samples"] + s2 = samples2["samples"] + + if s1.shape[1:] != s2.shape[1:]: + s2 = ldm_patched.modules.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center") + s = torch.cat((s1, s2), dim=0) + samples_out["samples"] = s + samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, + "LatentBatch": LatentBatch, } diff --git a/ldm_patched/contrib/external_model_advanced.py b/ldm_patched/contrib/external_model_advanced.py index 4ebd9db..03a2f04 100644 --- a/ldm_patched/contrib/external_model_advanced.py +++ b/ldm_patched/contrib/external_model_advanced.py @@ -19,41 +19,19 @@ class LCM(ldm_patched.modules.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class ModelSamplingDiscreteDistilled(torch.nn.Module): +class ModelSamplingDiscreteDistilled(ldm_patched.modules.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 - def __init__(self): - super().__init__() - self.sigma_data = 1.0 - timesteps = 1000 - beta_start = 0.00085 - beta_end = 0.012 + def __init__(self, model_config=None): + super().__init__(model_config) - betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) + self.skip_steps = self.num_timesteps // self.original_timesteps - self.skip_steps = timesteps // self.original_timesteps - - - alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) for x in range(self.original_timesteps): - alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps] - sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 - self.set_sigmas(sigmas) - - def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) - - @property - def sigma_min(self): - return self.sigmas[0] - - @property - def sigma_max(self): - return self.sigmas[-1] + self.set_sigmas(sigmas_valid) def timestep(self, sigma): log_sigma = sigma.log() @@ -68,14 +46,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module): log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] return log_sigma.exp().to(timestep.device) - def percent_to_sigma(self, percent): - if percent <= 0.0: - return 999999999.9 - if percent >= 1.0: - return 0.0 - percent = 1.0 - percent - return self.sigma(torch.tensor(percent * 999.0)).item() - def rescale_zero_terminal_snr_sigmas(sigmas): alphas_cumprod = 1 / ((sigmas * sigmas) + 1) @@ -124,7 +94,7 @@ class ModelSamplingDiscrete: class ModelSamplingAdvanced(sampling_base, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) if zsnr: model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) @@ -156,7 +126,7 @@ class ModelSamplingContinuousEDM: class ModelSamplingAdvanced(ldm_patched.modules.model_sampling.ModelSamplingContinuousEDM, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling.set_sigma_range(sigma_min, sigma_max) m.add_object_patch("model_sampling", model_sampling) return (m, ) diff --git a/ldm_patched/contrib/external_perpneg.py b/ldm_patched/contrib/external_perpneg.py new file mode 100644 index 0000000..ec91681 --- /dev/null +++ b/ldm_patched/contrib/external_perpneg.py @@ -0,0 +1,57 @@ +# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py + +import torch +import ldm_patched.modules.model_management +import ldm_patched.modules.sample +import ldm_patched.modules.samplers +import ldm_patched.modules.utils + + +class PerpNeg: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "empty_conditioning": ("CONDITIONING", ), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, empty_conditioning, neg_scale): + m = model.clone() + nocond = ldm_patched.modules.sample.convert_cond(empty_conditioning) + + def cfg_function(args): + model = args["model"] + noise_pred_pos = args["cond_denoised"] + noise_pred_neg = args["uncond_denoised"] + cond_scale = args["cond_scale"] + x = args["input"] + sigma = args["sigma"] + model_options = args["model_options"] + nocond_processed = ldm_patched.modules.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") + + (noise_pred_nocond, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options) + + pos = noise_pred_pos - noise_pred_nocond + neg = noise_pred_neg - noise_pred_nocond + perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg + perp_neg = perp * neg_scale + cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) + cfg_result = x - cfg_result + return cfg_result + + m.set_model_sampler_cfg_function(cfg_function) + + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "PerpNeg": PerpNeg, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PerpNeg": "Perp-Neg", +} diff --git a/ldm_patched/contrib/external_sag.py b/ldm_patched/contrib/external_sag.py index 3505b44..06ca67f 100644 --- a/ldm_patched/contrib/external_sag.py +++ b/ldm_patched/contrib/external_sag.py @@ -60,7 +60,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): attn = attn.reshape(b, -1, hw1, hw2) # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - ratio = round(math.sqrt(lh * lw / hw1)) + ratio = math.ceil(math.sqrt(lh * lw / hw1)) mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] # Reshape diff --git a/ldm_patched/modules/clip_vision.py b/ldm_patched/modules/clip_vision.py index eda441a..9699210 100644 --- a/ldm_patched/modules/clip_vision.py +++ b/ldm_patched/modules/clip_vision.py @@ -19,11 +19,13 @@ class Output: def clip_preprocess(image, size=224): mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) - scale = (size / min(image.shape[1], image.shape[2])) - image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + scale = (size / min(image.shape[2], image.shape[3])) + image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) @@ -34,11 +36,9 @@ class ClipVisionModel(): self.load_device = ldm_patched.modules.model_management.text_encoder_device() offload_device = ldm_patched.modules.model_management.text_encoder_offload_device() - self.dtype = torch.float32 - if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False): - self.dtype = torch.float16 - - self.model = ldm_patched.modules.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ldm_patched.modules.ops.disable_weight_init) + self.dtype = ldm_patched.modules.model_management.text_encoder_dtype(self.load_device) + self.model = ldm_patched.modules.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ldm_patched.modules.ops.manual_cast) + self.model.eval() self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): @@ -46,15 +46,8 @@ class ClipVisionModel(): def encode_image(self, image): ldm_patched.modules.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device)) - - if self.dtype != torch.float32: - precision_scope = torch.autocast - else: - precision_scope = lambda a, b: contextlib.nullcontext(a) - - with precision_scope(ldm_patched.modules.model_management.get_autocast_device(self.load_device), torch.float32): - out = self.model(pixel_values=pixel_values, intermediate_output=-2) + pixel_values = clip_preprocess(image.to(self.load_device)).float() + out = self.model(pixel_values=pixel_values, intermediate_output=-2) outputs = Output() outputs["last_hidden_state"] = out[0].to(ldm_patched.modules.model_management.intermediate_device()) diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py index 4e13d72..bfcb3f5 100644 --- a/ldm_patched/modules/samplers.py +++ b/ldm_patched/modules/samplers.py @@ -251,7 +251,8 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} cfg_result = x - model_options["sampler_cfg_function"](args) else: cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale diff --git a/troubleshoot.md b/troubleshoot.md index 7be743d..7e07974 100644 --- a/troubleshoot.md +++ b/troubleshoot.md @@ -143,19 +143,19 @@ Besides, the current support for MAC is very experimental, and we encourage user ### I am using Nvidia with 8GB VRAM, I get CUDA Out Of Memory -It is a BUG. Please let us know as soon as possible. Please make an issue. See also [minimal requirements](readme.md#minimal-requirement). +It is a BUG. Please let us know as soon as possible. Please make an issue. See also [minimal requirements](https://github.com/lllyasviel/Fooocus/tree/main?tab=readme-ov-file#minimal-requirement). ### I am using Nvidia with 6GB VRAM, I get CUDA Out Of Memory -It is very likely a BUG. Please let us know as soon as possible. Please make an issue. See also [minimal requirements](readme.md#minimal-requirement). +It is very likely a BUG. Please let us know as soon as possible. Please make an issue. See also [minimal requirements](https://github.com/lllyasviel/Fooocus/tree/main?tab=readme-ov-file#minimal-requirement). ### I am using Nvidia with 4GB VRAM with Float16 support, like RTX 3050, I get CUDA Out Of Memory -It is a BUG. Please let us know as soon as possible. Please make an issue. See also [minimal requirements](readme.md#minimal-requirement). +It is a BUG. Please let us know as soon as possible. Please make an issue. See also [minimal requirements](https://github.com/lllyasviel/Fooocus/tree/main?tab=readme-ov-file#minimal-requirement). ### I am using Nvidia with 4GB VRAM without Float16 support, like GTX 960, I get CUDA Out Of Memory -Supporting GPU with 4GB VRAM without fp16 is extremely difficult, and you may not be able to use SDXL. However, you may still make an issue and let us know. You may try SD1.5 in Automatic1111 or other software for your device. See also [minimal requirements](readme.md#minimal-requirement). +Supporting GPU with 4GB VRAM without fp16 is extremely difficult, and you may not be able to use SDXL. However, you may still make an issue and let us know. You may try SD1.5 in Automatic1111 or other software for your device. See also [minimal requirements](https://github.com/lllyasviel/Fooocus/tree/main?tab=readme-ov-file#minimal-requirement). ### I am using AMD GPU on Windows, I get CUDA Out Of Memory @@ -163,11 +163,11 @@ Current AMD support is very experimental for Windows. If you see this, then perh However, if you re able to run SDXL on this same device on any other software, please let us know immediately, and we will support it as soon as possible. If no other software can enable your device to run SDXL on Windows, then we also do not have much to help. -Besides, the AMD support on Linux is slightly better because it will use ROCM. You may also try it if you are willing to change OS to linux. See also [minimal requirements](readme.md#minimal-requirement). +Besides, the AMD support on Linux is slightly better because it will use ROCM. You may also try it if you are willing to change OS to linux. See also [minimal requirements](https://github.com/lllyasviel/Fooocus/tree/main?tab=readme-ov-file#minimal-requirement). ### I am using AMD GPU on Linux, I get CUDA Out Of Memory -Current AMD support for Linux is better than that for Windows, but still, very experimental. However, if you re able to run SDXL on this same device on any other software, please let us know immediately, and we will support it as soon as possible. If no other software can enable your device to run SDXL on Windows, then we also do not have much to help. See also [minimal requirements](readme.md#minimal-requirement). +Current AMD support for Linux is better than that for Windows, but still, very experimental. However, if you re able to run SDXL on this same device on any other software, please let us know immediately, and we will support it as soon as possible. If no other software can enable your device to run SDXL on Windows, then we also do not have much to help. See also [minimal requirements](https://github.com/lllyasviel/Fooocus/tree/main?tab=readme-ov-file#minimal-requirement). ### I tried flags like --lowvram or --gpu-only or --bf16 or so on, and things are not getting any better?