mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-05 22:06:12 +02:00
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Close stale issues / stale (push) Has been cancelled
This should fix the import errors but will break if the custom nodes actually try to use the class.
64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
import torch
|
|
from torch import Tensor, nn
|
|
|
|
from comfy.ldm.flux.layers import (
|
|
MLPEmbedder,
|
|
RMSNorm,
|
|
ModulationOut,
|
|
)
|
|
|
|
# TODO: remove this in a few months
# Kept as None placeholders so that external code importing these names
# still resolves (presumably legacy custom nodes — anything that actually
# instantiates them will fail at runtime; confirm before removing).
SingleStreamBlock = DoubleStreamBlock = None
|
|
|
|
|
|
class ChromaModulationOut(ModulationOut):
    """ModulationOut whose (shift, scale, gate) triple is read out of three
    consecutive width-1 slices along dim 1 of a stacked modulation tensor."""

    @classmethod
    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
        """Build the modulation triple from slices [offset, offset+3) of dim 1.

        Each field keeps its singleton dim-1 axis (slices are width 1,
        not indexed away).
        """
        fields = {
            name: tensor[:, offset + i : offset + i + 1, :]
            for i, name in enumerate(("shift", "scale", "gate"))
        }
        return cls(**fields)
|
|
|
|
|
|
|
|
|
|
class Approximator(nn.Module):
    """Small MLP that maps an input embedding to an output vector through a
    stack of pre-normalized residual blocks.

    Structure: in_proj -> n_layers of ``x = x + MLPEmbedder(RMSNorm(x))``
    -> out_proj.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5, dtype=None, device=None, operations=None):
        """
        Args:
            in_dim: input feature size.
            out_dim: output feature size.
            hidden_dim: width of the residual stack.
            n_layers: number of residual norm+MLP blocks.
            dtype/device: forwarded to every submodule.
            operations: factory providing ``Linear`` (project-specific ops
                namespace — required, no default behavior without it).
        """
        super().__init__()
        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        # One norm per residual layer, applied before its MLP (pre-norm).
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def forward(self, x: Tensor) -> Tensor:
        """Project to hidden_dim, apply the residual stack, project to out_dim."""
        x = self.in_proj(x)
        for layer, norm in zip(self.layers, self.norms):
            x = x + layer(norm(x))
        return self.out_proj(x)
|
|
|
|
|
|
class LastLayer(nn.Module):
    """Final output head: non-affine LayerNorm modulated by an external
    (shift, scale) pair, followed by a linear projection.

    NOTE(review): `vec` is unpacked as a (shift, scale) pair whose elements
    each carry a singleton dim-1 axis — assumed from the squeeze(1) calls;
    confirm against the caller.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        mod_shift, mod_scale = vec
        # Drop the singleton token axis, then re-insert it so the modulation
        # broadcasts over the sequence dimension.
        mod_shift = mod_shift.squeeze(1)[:, None, :]
        mod_scale = mod_scale.squeeze(1)[:, None, :]
        normed = self.norm_final(x)
        # x <- shift + (1 + scale) * LayerNorm(x), fused via addcmul.
        modulated = torch.addcmul(mod_shift, 1 + mod_scale, normed)
        return self.linear(modulated)
|