mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-05 05:46:12 +02:00
Merge branch 'master' into feat/api-nodes/openai-gpt5.5
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
This commit is contained in:
commit
db1243e54c
13
README.md
13
README.md
@ -1,7 +1,7 @@
|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular visual AI engine and application.**
|
||||
**The most powerful and modular AI engine for content creation.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
@ -31,10 +31,16 @@
|
||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||

|
||||
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
|
||||
<br>
|
||||
</div>
|
||||
|
||||
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
|
||||
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
|
||||
- ComfyUI natively supports the latest open-source state of the art models.
|
||||
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
|
||||
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
|
||||
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
|
||||
- It integrates seamlessly into production pipelines with our API endpoints.
|
||||
|
||||
## Get Started
|
||||
|
||||
@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Ernie Image
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
|
||||
@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
|
||||
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
||||
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
NoPreviews = "none"
|
||||
|
||||
@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_prefetch
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
|
||||
@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
|
||||
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
|
||||
n_rep = q.shape[-3] // k.shape[-3]
|
||||
k = k.repeat_interleave(n_rep, dim=-3)
|
||||
v = v.repeat_interleave(n_rep, dim=-3)
|
||||
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
h = heads
|
||||
if skip_reshape:
|
||||
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
b, _, dim_head = query.shape
|
||||
dim_head //= heads
|
||||
|
||||
if "scale" in kwargs:
|
||||
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
|
||||
query = query * (kwargs["scale"] * dim_head ** 0.5)
|
||||
|
||||
if skip_reshape:
|
||||
query = query.reshape(b * heads, -1, dim_head)
|
||||
value = value.reshape(b * heads, -1, dim_head)
|
||||
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
|
||||
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
|
||||
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
|
||||
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
k[i : i + SDP_BATCH_LIMIT],
|
||||
v[i : i + SDP_BATCH_LIMIT],
|
||||
attn_mask=m,
|
||||
dropout_p=0.0, is_causal=False
|
||||
dropout_p=0.0, is_causal=False, **sdpa_extra
|
||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_base
|
||||
@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
def prefetch_prepared_value(value, allocate_buffer, stream):
|
||||
if isinstance(value, torch.Tensor):
|
||||
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
|
||||
elif isinstance(value, weight_adapter.WeightAdapterBase):
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
|
||||
elif isinstance(value, list):
|
||||
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module):
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["prefetch_dynamic_vbars"] = (
|
||||
self.current_patcher is not None and self.current_patcher.is_dynamic()
|
||||
)
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
|
||||
@ -31,6 +31,7 @@ from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -1175,6 +1176,10 @@ stream_counters = {}
|
||||
|
||||
STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
|
||||
def get_cast_buffer(offload_stream, device, size, ref):
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def get_aimdo_cast_buffer(offload_stream, device):
|
||||
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
|
||||
if cast_buffer is None:
|
||||
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
return cast_buffer
|
||||
def reset_cast_buffers():
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||
if offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
|
||||
@ -121,9 +121,20 @@ class LowVramPatch:
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func # TODO: remove
|
||||
self.set_func = set_func
|
||||
self.prepared_patches = None
|
||||
|
||||
def prepare(self, allocate_buffer, stream):
|
||||
self.prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
|
||||
for patch in self.patches[self.key]
|
||||
]
|
||||
|
||||
def clear_prepared(self):
|
||||
self.prepared_patches = None
|
||||
|
||||
def __call__(self, weight):
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
|
||||
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
|
||||
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||
|
||||
|
||||
65
comfy/model_prefetch.py
Normal file
65
comfy/model_prefetch.py
Normal file
@ -0,0 +1,65 @@
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def cleanup_prefetched_modules(comfy_modules):
|
||||
for s in comfy_modules:
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
if prefetch is None:
|
||||
continue
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
if prefetch["signature"] is not None:
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
delattr(s, "_prefetch")
|
||||
|
||||
def cleanup_prefetch_queues():
|
||||
global PREFETCH_QUEUES
|
||||
|
||||
for queue in PREFETCH_QUEUES:
|
||||
for entry in queue:
|
||||
if entry is None or not isinstance(entry, tuple):
|
||||
continue
|
||||
_, prefetch_state = entry
|
||||
comfy_modules = prefetch_state[1]
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def prefetch_queue_pop(queue, device, module):
|
||||
if queue is None:
|
||||
return
|
||||
|
||||
consumed = queue.pop(0)
|
||||
if consumed is not None:
|
||||
offload_stream, prefetch_state = consumed
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
_, comfy_modules = prefetch_state
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
|
||||
prefetch = queue[0]
|
||||
if prefetch is not None:
|
||||
comfy_modules = []
|
||||
for s in prefetch.modules():
|
||||
if hasattr(s, "_v"):
|
||||
comfy_modules.append(s)
|
||||
|
||||
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||
|
||||
def make_prefetch_queue(queue, device, transformer_options):
|
||||
if (not transformer_options.get("prefetch_dynamic_vbars", False)
|
||||
or comfy.model_management.NUM_STREAMS == 0
|
||||
or comfy.model_management.is_device_cpu(device)
|
||||
or not comfy.model_management.device_supports_non_blocking(device)):
|
||||
return None
|
||||
|
||||
queue = [None] + queue + [None]
|
||||
PREFETCH_QUEUES.append(queue)
|
||||
return queue
|
||||
268
comfy/ops.py
268
comfy/ops.py
@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
|
||||
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
||||
return weight, bias, (None, None, None)
|
||||
|
||||
# FIXME: add n=1 cache hit fast path
|
||||
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
cast_buffer = None
|
||||
cast_buffer_offset = 0
|
||||
|
||||
def ensure_offload_stream(module, required_size, check_largest):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
|
||||
if offload_stream is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
|
||||
return
|
||||
|
||||
current_size = 0 if cast_buffer is None else cast_buffer.size()
|
||||
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
cast_buffer = None
|
||||
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
|
||||
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
|
||||
|
||||
def get_cast_buffer(buffer_size):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
nonlocal cast_buffer_offset
|
||||
|
||||
if buffer_size == 0:
|
||||
return None
|
||||
|
||||
if offload_stream is None:
|
||||
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
|
||||
|
||||
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
||||
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
|
||||
cast_buffer_offset += buffer_size
|
||||
return buffer
|
||||
|
||||
for s in comfy_modules:
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
prefetch = {
|
||||
"signature": signature,
|
||||
"resident": resident,
|
||||
}
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
s._prefetch = prefetch
|
||||
continue
|
||||
|
||||
if not resident:
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
needs_cast = False
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
|
||||
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
if data is None:
|
||||
continue
|
||||
if data.dtype != geometry.dtype:
|
||||
needs_cast = True
|
||||
cast_dest = xfer_dest
|
||||
if cast_dest is None:
|
||||
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
||||
xfer_dest = None
|
||||
break
|
||||
|
||||
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if xfer_dest is None and offload_stream is not None:
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
if xfer_dest is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
|
||||
if xfer_dest is None:
|
||||
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
||||
offload_stream = None
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
if cast_dest is not None:
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
prefetch["cast_geometry"] = cast_geometry
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
return offload_stream
|
||||
|
||||
|
||||
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
|
||||
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
|
||||
if prefetch["resident"]:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = prefetch["xfer_dest"]
|
||||
if prefetch["needs_cast"]:
|
||||
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
|
||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
|
||||
if post_cast is not None:
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
if prefetch["signature"] is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
s._v_signature = prefetch["signature"]
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
fns = getattr(s, param_key + "_function", [])
|
||||
|
||||
if x is None:
|
||||
return None
|
||||
|
||||
orig = x
|
||||
|
||||
def to_dequant(tensor, dtype):
|
||||
@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
update_weight = signature is not None
|
||||
update_weight = prefetch["signature"] is not None
|
||||
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
|
||||
if bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
|
||||
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
return weight, bias
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||
@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
def format_return(result, offloadable):
|
||||
weight, bias, offload_stream = result
|
||||
return (weight, bias, offload_stream) if offloadable else (weight, bias)
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
|
||||
prefetched = hasattr(s, "_prefetch")
|
||||
offload_stream = None
|
||||
offload_device = None
|
||||
if not prefetched:
|
||||
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
|
||||
|
||||
if not prefetched:
|
||||
if getattr(s, "_prefetch")["signature"] is not None:
|
||||
offload_device = device
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
delattr(s, "_prefetch")
|
||||
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
|
||||
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
if offloadable:
|
||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
|
||||
|
||||
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
@ -1173,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
class Embedding(manual_cast.Embedding):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
weight_key = f"{prefix}weight"
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
# Only fp8 makes sense for embeddings (per-row dequant via index select).
|
||||
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
|
||||
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
|
||||
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
|
||||
self.quant_format = quant_format
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
|
||||
weight = state_dict.pop(weight_key)
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(scale_key, None)
|
||||
if scale is not None:
|
||||
scale = scale.float()
|
||||
manually_loaded_keys.append(scale_key)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.num_embeddings, self.embedding_dim),
|
||||
)
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
|
||||
requires_grad=False)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
else:
|
||||
if layer_conf is not None:
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight') or self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
|
||||
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
||||
weight = self.weight
|
||||
|
||||
# Optimized path: lookup in fp8, dequantize only the selected rows.
|
||||
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
|
||||
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
|
||||
if isinstance(qdata, QuantizedTensor):
|
||||
scale = qdata._params.scale
|
||||
qdata = qdata._qdata
|
||||
else:
|
||||
scale = None
|
||||
|
||||
x = torch.nn.functional.embedding(
|
||||
input, qdata, self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
uncast_bias_weight(self, qdata, None, offload_stream)
|
||||
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
|
||||
x = x.to(dtype=target_dtype)
|
||||
if scale is not None and scale != 1.0:
|
||||
x = x * scale.to(dtype=target_dtype)
|
||||
return x
|
||||
|
||||
# Fallback for non-quantized or weight_function (LoRA) case
|
||||
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
|
||||
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
from comfy_kitchen.tensor import (
|
||||
@ -21,7 +23,15 @@ try:
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
if args.enable_triton_backend:
|
||||
try:
|
||||
import triton
|
||||
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
|
||||
ck.registry.disable("triton")
|
||||
else:
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
except ImportError as e:
|
||||
|
||||
@ -3,6 +3,7 @@ import comfy.model_management
|
||||
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
|
||||
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if weight is None:
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||
|
||||
17
comfy/sd.py
17
comfy/sd.py
@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -1271,6 +1272,9 @@ class TEModel(Enum):
|
||||
QWEN35_9B = 26
|
||||
QWEN35_27B = 27
|
||||
MINISTRAL_3_3B = 28
|
||||
GEMMA_4_E4B = 29
|
||||
GEMMA_4_E2B = 30
|
||||
GEMMA_4_31B = 31
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1296,6 +1300,12 @@ def detect_te_model(sd):
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
if 'model.layers.59.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_4_31B
|
||||
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E4B
|
||||
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E2B
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
|
||||
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
|
||||
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
|
||||
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
|
||||
clip_target.tokenizer = variant.tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
|
||||
1298
comfy/text_encoders/gemma4.py
Normal file
1298
comfy/text_encoders/gemma4.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -521,7 +521,7 @@ class Attention(nn.Module):
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window:
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
|
||||
xk = xk[:, :, -sliding_window:]
|
||||
xv = xv[:, :, -sliding_window:]
|
||||
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||
@ -533,12 +533,12 @@ class Attention(nn.Module):
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
|
||||
super().__init__()
|
||||
ops = ops or nn
|
||||
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
intermediate_size = intermediate_size or config.intermediate_size
|
||||
self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
if config.mlp_activation == "silu":
|
||||
self.activation = torch.nn.functional.silu
|
||||
elif config.mlp_activation == "gelu_pytorch_tanh":
|
||||
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
|
||||
|
||||
return x, present_key_value
|
||||
|
||||
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
|
||||
class ScaledEmbedding(ops.Embedding):
|
||||
def forward(self, input_ids, out_dtype=None):
|
||||
return super().forward(input_ids, out_dtype=out_dtype) * scale
|
||||
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = ops.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
|
||||
transformer = TransformerBlockGemma2
|
||||
self.normalize_in = True
|
||||
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
|
||||
else:
|
||||
transformer = TransformerBlock
|
||||
self.normalize_in = False
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
|
||||
self.config.rope_dims,
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
@ -850,7 +848,7 @@ class BaseGenerate:
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
@ -875,14 +873,16 @@ class BaseGenerate:
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
|
||||
# Generation loop
|
||||
current_input_ids = initial_input_ids
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||
current_input_ids = next_token if initial_input_ids is not None else None
|
||||
pbar.update(1)
|
||||
|
||||
if token_id in stop_tokens:
|
||||
|
||||
@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
|
||||
|
||||
class DualLinearProjection(torch.nn.Module):
|
||||
|
||||
@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
embeds, _, _, _ = super().process_tokens(tokens, device)
|
||||
return embeds
|
||||
|
||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
|
||||
@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.normalize_in = False
|
||||
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
|
||||
"""Normalize image embeddings to match text embedding scale"""
|
||||
for info in embeds_info:
|
||||
if info.get("type") == "image":
|
||||
start_idx = info["index"]
|
||||
end_idx = start_idx + info["size"]
|
||||
embeds[:, start_idx:end_idx, :] /= scale_factor
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
|
||||
grain: Optional[float] = Field(None, description="Grain after AI model processing")
|
||||
grainSize: Optional[float] = Field(None, description="Size of generated grain")
|
||||
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
|
||||
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
|
||||
creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
|
||||
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
|
||||
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
|
||||
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
|
||||
realism: float | None = Field(None, description="ast-2 realism control")
|
||||
|
||||
|
||||
class OutputInformationVideo(BaseModel):
|
||||
@ -90,7 +93,7 @@ class Overrides(BaseModel):
|
||||
|
||||
class CreateVideoRequest(BaseModel):
|
||||
source: CreateVideoRequestSource = Field(...)
|
||||
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
|
||||
filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
|
||||
output: OutputInformationVideo = Field(...)
|
||||
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
|
||||
|
||||
|
||||
@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
UPSCALER_MODELS_MAP = {
|
||||
"Astra 2": "ast-2",
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
"Starlight Precise 2.5": "slp-2.5",
|
||||
}
|
||||
|
||||
AST2_MAX_FRAMES = 9000
|
||||
AST2_MAX_FRAMES_WITH_PROMPT = 450
|
||||
|
||||
|
||||
class TopazImageEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhance",
|
||||
display_name="Topaz Video Enhance",
|
||||
display_name="Topaz Video Enhance (Legacy)",
|
||||
category="api node/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Boolean.Input("upscaler_enabled", default=True),
|
||||
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
||||
IO.Combo.Input(
|
||||
"upscaler_model",
|
||||
options=[
|
||||
"Starlight (Astra) Fast",
|
||||
"Starlight (Astra) Creative",
|
||||
"Starlight Precise 2.5",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Combo.Input(
|
||||
"upscaler_creativity",
|
||||
@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
|
||||
|
||||
|
||||
class TopazVideoEnhanceV2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhanceV2",
|
||||
display_name="Topaz Video Enhance",
|
||||
category="api node/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.DynamicCombo.Input(
|
||||
"upscaler_model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Astra 2",
|
||||
[
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Creative strength of the upscale.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Optional descriptive (not instructive) scene prompt."
|
||||
f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"sharp",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Pre-enhance sharpness: "
|
||||
"0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"realism",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Pulls output toward photographic realism."
|
||||
"Leave at 0 for the model default.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Starlight (Astra) Fast",
|
||||
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Starlight (Astra) Creative",
|
||||
[
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Combo.Input(
|
||||
"creativity",
|
||||
options=["low", "middle", "high"],
|
||||
default="low",
|
||||
tooltip="Creative strength of the upscale.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Starlight Precise 2.5",
|
||||
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
|
||||
),
|
||||
IO.DynamicCombo.Option("Disabled", []),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"interpolation_model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Disabled", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"apo-8",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"interpolation_frame_rate",
|
||||
default=60,
|
||||
min=15,
|
||||
max=240,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Output frame rate.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"interpolation_slowmo",
|
||||
default=1,
|
||||
min=1,
|
||||
max=16,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Slow-motion factor applied to the input video. "
|
||||
"For example, 2 makes the output twice as slow and doubles the duration.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"interpolation_duplicate",
|
||||
default=False,
|
||||
tooltip="Analyze the input for duplicate frames and remove them.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"interpolation_duplicate_threshold",
|
||||
default=0.01,
|
||||
min=0.001,
|
||||
max=0.1,
|
||||
step=0.001,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Detection sensitivity for duplicate frames.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"dynamic_compression_level",
|
||||
options=["Low", "Mid", "High"],
|
||||
default="Low",
|
||||
tooltip="CQP level.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=[
|
||||
"upscaler_model",
|
||||
"upscaler_model.upscaler_resolution",
|
||||
"interpolation_model",
|
||||
]),
|
||||
expr="""
|
||||
(
|
||||
$model := $lookup(widgets, "upscaler_model");
|
||||
$res := $lookup(widgets, "upscaler_model.upscaler_resolution");
|
||||
$interp := $lookup(widgets, "interpolation_model");
|
||||
$is4k := $contains($res, "4k");
|
||||
$hasInterp := $interp != "disabled";
|
||||
$rates := {
|
||||
"starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
|
||||
"starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
|
||||
"astra 2": {"hd": 1.72, "uhd": 2.85},
|
||||
"starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
|
||||
};
|
||||
$surcharge := $is4k ? 0.28 : 0.14;
|
||||
$entry := $lookup($rates, $model);
|
||||
$base := $is4k ? $entry.uhd : $entry.hd;
|
||||
$hi := $base + ($hasInterp ? $surcharge : 0);
|
||||
$model = "disabled"
|
||||
? {"type":"text","text":"Interpolation only"}
|
||||
: ($hasInterp
|
||||
? {"type":"text","text":"~" & $string($base) & "–" & $string($hi) & " credits/src frame"}
|
||||
: {"type":"text","text":"~" & $string($base) & " credits/src frame"})
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
upscaler_model: dict,
|
||||
interpolation_model: dict,
|
||||
dynamic_compression_level: str = "Low",
|
||||
) -> IO.NodeOutput:
|
||||
upscaler_choice = upscaler_model["upscaler_model"]
|
||||
interpolation_choice = interpolation_model["interpolation_model"]
|
||||
if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
|
||||
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
|
||||
validate_container_format_is_mp4(video)
|
||||
src_width, src_height = video.get_dimensions()
|
||||
src_frame_rate = int(video.get_frame_rate())
|
||||
duration_sec = video.get_duration()
|
||||
src_video_stream = video.get_stream_source()
|
||||
target_width = src_width
|
||||
target_height = src_height
|
||||
target_frame_rate = src_frame_rate
|
||||
filters = []
|
||||
if upscaler_choice != "Disabled":
|
||||
if "1080p" in upscaler_model["upscaler_resolution"]:
|
||||
target_pixel_p = 1080
|
||||
max_long_side = 1920
|
||||
else:
|
||||
target_pixel_p = 2160
|
||||
max_long_side = 3840
|
||||
ar = src_width / src_height
|
||||
if src_width >= src_height:
|
||||
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
|
||||
target_height = target_pixel_p
|
||||
target_width = int(target_height * ar)
|
||||
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
|
||||
if target_width > max_long_side:
|
||||
target_width = max_long_side
|
||||
target_height = int(target_width / ar)
|
||||
else:
|
||||
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
|
||||
target_width = target_pixel_p
|
||||
target_height = int(target_width / ar)
|
||||
# Check if height exceeds standard bounds
|
||||
if target_height > max_long_side:
|
||||
target_height = max_long_side
|
||||
target_width = int(target_height * ar)
|
||||
if target_width % 2 != 0:
|
||||
target_width += 1
|
||||
if target_height % 2 != 0:
|
||||
target_height += 1
|
||||
model_id = UPSCALER_MODELS_MAP[upscaler_choice]
|
||||
if model_id == "slc-1":
|
||||
filters.append(
|
||||
VideoEnhancementFilter(
|
||||
model=model_id,
|
||||
creativity=upscaler_model["creativity"],
|
||||
isOptimizedMode=True,
|
||||
)
|
||||
)
|
||||
elif model_id == "ast-2":
|
||||
n_frames = video.get_frame_count()
|
||||
ast2_prompt = (upscaler_model["prompt"] or "").strip()
|
||||
if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
|
||||
raise ValueError(
|
||||
f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
|
||||
f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
|
||||
)
|
||||
if n_frames > AST2_MAX_FRAMES:
|
||||
raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
|
||||
realism = upscaler_model["realism"]
|
||||
filters.append(
|
||||
VideoEnhancementFilter(
|
||||
model=model_id,
|
||||
creativity=upscaler_model["creativity"],
|
||||
prompt=(ast2_prompt or None),
|
||||
sharp=upscaler_model["sharp"],
|
||||
realism=(realism if realism > 0 else None),
|
||||
)
|
||||
)
|
||||
else:
|
||||
filters.append(VideoEnhancementFilter(model=model_id))
|
||||
if interpolation_choice != "Disabled":
|
||||
target_frame_rate = interpolation_model["interpolation_frame_rate"]
|
||||
filters.append(
|
||||
VideoFrameInterpolationFilter(
|
||||
model=interpolation_choice,
|
||||
slowmo=interpolation_model["interpolation_slowmo"],
|
||||
fps=interpolation_model["interpolation_frame_rate"],
|
||||
duplicate=interpolation_model["interpolation_duplicate"],
|
||||
duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
|
||||
),
|
||||
)
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
|
||||
response_model=CreateVideoResponse,
|
||||
data=CreateVideoRequest(
|
||||
source=CreateVideoRequestSource(
|
||||
container="mp4",
|
||||
size=get_fs_object_size(src_video_stream),
|
||||
duration=int(duration_sec),
|
||||
frameCount=video.get_frame_count(),
|
||||
frameRate=src_frame_rate,
|
||||
resolution=Resolution(width=src_width, height=src_height),
|
||||
),
|
||||
filters=filters,
|
||||
output=OutputInformationVideo(
|
||||
resolution=Resolution(width=target_width, height=target_height),
|
||||
frameRate=target_frame_rate,
|
||||
audioCodec="AAC",
|
||||
audioTransfer="Copy",
|
||||
dynamicCompressionLevel=dynamic_compression_level,
|
||||
),
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
upload_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=VideoAcceptResponse,
|
||||
wait_label="Preparing upload",
|
||||
final_label_on_success="Upload started",
|
||||
)
|
||||
if len(upload_res.urls) > 1:
|
||||
raise NotImplementedError(
|
||||
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
|
||||
)
|
||||
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
|
||||
if isinstance(src_video_stream, BytesIO):
|
||||
src_video_stream.seek(0)
|
||||
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
else:
|
||||
with builtins.open(src_video_stream, "rb") as video_file:
|
||||
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=VideoCompleteUploadResponse,
|
||||
data=VideoCompleteUploadRequest(
|
||||
uploadResults=[
|
||||
VideoCompleteUploadRequestPart(
|
||||
partNum=1,
|
||||
eTag=upload_etag,
|
||||
),
|
||||
],
|
||||
),
|
||||
wait_label="Finalizing upload",
|
||||
final_label_on_success="Upload completed",
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
|
||||
response_model=VideoStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
|
||||
poll_interval=10.0,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
|
||||
|
||||
|
||||
class TopazExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TopazImageEnhance,
|
||||
TopazVideoEnhance,
|
||||
TopazVideoEnhanceV2,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -199,6 +199,9 @@ class FILMNet(nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
|
||||
|
||||
def memory_used_forward(self, shape, dtype):
|
||||
return 1700 * shape[1] * shape[2] * dtype.itemsize
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
"""Pre-compute warp grids for all pyramid levels."""
|
||||
if (H, W) in self._warp_grids:
|
||||
|
||||
@ -74,6 +74,9 @@ class IFNet(nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.encode.cnn0.weight.dtype
|
||||
|
||||
def memory_used_forward(self, shape, dtype):
|
||||
return 300 * shape[1] * shape[2] * dtype.itemsize
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
if (H, W) in self._warp_grids:
|
||||
return
|
||||
|
||||
@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||
batch_size = min(len(image), len(alpha))
|
||||
out_images = []
|
||||
|
||||
batch_size = max(len(image), len(alpha))
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
||||
return io.NodeOutput(torch.stack(out_images))
|
||||
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
|
||||
image = comfy.utils.repeat_to_batch_size(image, batch_size)
|
||||
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
|
||||
|
||||
|
||||
class CompositingExtension(ComfyExtension):
|
||||
|
||||
@ -37,7 +37,7 @@ class FrameInterpolationModelLoader(io.ComfyNode):
|
||||
model = cls._detect_and_load(sd)
|
||||
dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
|
||||
model.eval().to(dtype)
|
||||
patcher = comfy.model_patcher.ModelPatcher(
|
||||
patcher = comfy.model_patcher.CoreModelPatcher(
|
||||
model,
|
||||
load_device=model_management.get_torch_device(),
|
||||
offload_device=model_management.unet_offload_device(),
|
||||
@ -98,16 +98,13 @@ class FrameInterpolate(io.ComfyNode):
|
||||
if num_frames < 2 or multiplier < 2:
|
||||
return io.NodeOutput(images)
|
||||
|
||||
model_management.load_model_gpu(interp_model)
|
||||
device = interp_model.load_device
|
||||
dtype = interp_model.model_dtype()
|
||||
inference_model = interp_model.model
|
||||
|
||||
# Free VRAM for inference activations (model weights + ~20x a single frame's worth)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
activation_mem = H * W * 3 * images.element_size() * 20
|
||||
model_management.free_memory(activation_mem, device)
|
||||
activation_mem = inference_model.memory_used_forward(images.shape, dtype)
|
||||
model_management.load_models_gpu([interp_model], memory_required=activation_mem)
|
||||
align = getattr(inference_model, "pad_align", 1)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
|
||||
# Prepare a single padded frame on device for determining output dimensions
|
||||
def prepare_frame(idx):
|
||||
|
||||
@ -666,12 +666,13 @@ class ColorTransfer(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ColorTransfer",
|
||||
display_name="Color Transfer",
|
||||
category="image/postprocessing",
|
||||
description="Match the colors of one image to another using various algorithms.",
|
||||
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
|
||||
inputs=[
|
||||
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
|
||||
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
|
||||
io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."),
|
||||
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
|
||||
io.DynamicCombo.Input("source_stats",
|
||||
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
|
||||
|
||||
@ -49,7 +49,7 @@ class Int(io.ComfyNode):
|
||||
display_name="Int",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Image.Input("image", optional=True),
|
||||
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
|
||||
io.Audio.Input("audio", optional=True),
|
||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||
@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
|
||||
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
|
||||
|
||||
# Get sampling parameters from dynamic combo
|
||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||
@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
|
||||
seed=seed
|
||||
)
|
||||
|
||||
generated_text = clip.decode(generated_ids, skip_special_tokens=True)
|
||||
generated_text = clip.decode(generated_ids)
|
||||
|
||||
return io.NodeOutput(generated_text)
|
||||
|
||||
|
||||
@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
|
||||
if image is None:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
else:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
|
||||
|
||||
|
||||
class TextgenExtension(ComfyExtension):
|
||||
|
||||
@ -15,6 +15,7 @@ import torch
|
||||
from comfy.cli_args import args
|
||||
import comfy.memory_management
|
||||
import comfy.model_management
|
||||
import comfy.model_prefetch
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
from latent_preview import set_preview_method
|
||||
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.control.analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
comfy.model_prefetch.cleanup_prefetch_queues()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
|
||||
if has_pending_tasks:
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
#config for a1111 ui
|
||||
#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
|
||||
|
||||
#a111:
|
||||
#a1111:
|
||||
# base_path: path/to/stable-diffusion-webui/
|
||||
# checkpoints: models/Stable-diffusion
|
||||
# configs: models/Stable-diffusion
|
||||
|
||||
@ -86,6 +86,6 @@ def image_alpha_fix(destination, source):
|
||||
if destination.shape[-1] < source.shape[-1]:
|
||||
source = source[...,:destination.shape[-1]]
|
||||
elif destination.shape[-1] > source.shape[-1]:
|
||||
destination = torch.nn.functional.pad(destination, (0, 1))
|
||||
destination[..., -1] = 1.0
|
||||
source = torch.nn.functional.pad(source, (0, 1))
|
||||
source[..., -1] = 1.0
|
||||
return destination, source
|
||||
|
||||
95
nodes.py
95
nodes.py
@ -1694,26 +1694,27 @@ class LoadImage:
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_image"
|
||||
|
||||
def load_image(self, image):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
device = comfy.model_management.intermediate_device()
|
||||
|
||||
components = InputImpl.VideoFromFile(image_path).get_components()
|
||||
if components.images.shape[0] > 0:
|
||||
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
|
||||
return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))
|
||||
|
||||
# This code is left here to handle animated webp which pyav does not support loading
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
|
||||
output_images = []
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
image = i.convert("RGB")
|
||||
|
||||
if len(output_images) == 0:
|
||||
@ -1728,25 +1729,15 @@ class LoadImage:
|
||||
if 'A' in i.getbands():
|
||||
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
elif i.mode == 'P' and 'transparency' in i.info:
|
||||
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image.to(dtype=dtype))
|
||||
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
|
||||
|
||||
if img.format == "MPO":
|
||||
break # ignore all frames except the first one for MPO format
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
|
||||
if len(output_images) > 1:
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
else:
|
||||
output_image = output_images[0]
|
||||
output_mask = output_masks[0]
|
||||
|
||||
return (output_image, output_mask)
|
||||
return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
@ -1763,57 +1754,49 @@ class LoadImage:
|
||||
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
|
||||
class LoadImageMask(LoadImage):
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
||||
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(files), {"image_upload": True}),
|
||||
"channel": (s._color_channels, ), }
|
||||
}
|
||||
types = super().INPUT_TYPES()
|
||||
return {
|
||||
"required": {
|
||||
**types["required"],
|
||||
"channel": (s._color_channels, )
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = node_helpers.pillow(Image.open, image_path)
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
i = i.convert("RGBA")
|
||||
mask = None
|
||||
FUNCTION = "load_image_mask"
|
||||
|
||||
def load_image_mask(self, image, channel):
|
||||
image_tensor, mask_tensor = super().load_image(image)
|
||||
c = channel[0].upper()
|
||||
if c in i.getbands():
|
||||
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
|
||||
mask = torch.from_numpy(mask)
|
||||
if c == 'A':
|
||||
mask = 1. - mask
|
||||
|
||||
if c == 'A':
|
||||
return (mask_tensor,)
|
||||
|
||||
channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0)
|
||||
|
||||
if channel_idx < image_tensor.shape[-1]:
|
||||
return (image_tensor[..., channel_idx].clone(),)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
return (mask.unsqueeze(0),)
|
||||
empty_mask = torch.zeros(
|
||||
image_tensor.shape[:-1],
|
||||
dtype=image_tensor.dtype,
|
||||
device=image_tensor.device
|
||||
)
|
||||
return (empty_mask,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
return True
|
||||
return super().IS_CHANGED(image)
|
||||
|
||||
|
||||
class LoadImageOutput(LoadImage):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.15
|
||||
comfyui-workflow-templates==0.9.66
|
||||
comfyui-workflow-templates==0.9.68
|
||||
comfyui-embedded-docs==0.4.4
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import errno
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
@ -1245,7 +1246,13 @@ class PromptServer():
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||
await site.start()
|
||||
try:
|
||||
await site.start()
|
||||
except OSError as e:
|
||||
if e.errno == errno.EADDRINUSE:
|
||||
logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
|
||||
raise SystemExit(1)
|
||||
raise
|
||||
|
||||
if not hasattr(self, 'address'):
|
||||
self.address = address #TODO: remove this
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user