mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-04 21:36:14 +02:00
feat(isolation): execution and extension loading integration
This commit is contained in:
parent
359f972a1b
commit
d1155318eb
540
comfy/isolation/extension_loader.py
Normal file
540
comfy/isolation/extension_loader.py
Normal file
@ -0,0 +1,540 @@
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
import pyisolate
|
||||
from pyisolate import ExtensionManager, ExtensionManagerConfig
|
||||
from packaging.requirements import InvalidRequirement, Requirement
|
||||
from packaging.utils import canonicalize_name
|
||||
|
||||
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
|
||||
from .host_policy import load_host_policy
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _register_web_directory(extension_name: str, node_dir: Path) -> None:
|
||||
"""Register an isolated extension's web directory on the host side."""
|
||||
import nodes
|
||||
|
||||
# Method 1: pyproject.toml [tool.comfy] web field
|
||||
pyproject = node_dir / "pyproject.toml"
|
||||
if pyproject.exists():
|
||||
try:
|
||||
with pyproject.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
web_dir_name = data.get("tool", {}).get("comfy", {}).get("web")
|
||||
if web_dir_name:
|
||||
web_dir_path = str(node_dir / web_dir_name)
|
||||
if os.path.isdir(web_dir_path):
|
||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||
logger.debug(
|
||||
"][ Registered web dir for isolated %s: %s",
|
||||
extension_name,
|
||||
web_dir_path,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Method 2: __init__.py WEB_DIRECTORY constant (parse without importing)
|
||||
init_file = node_dir / "__init__.py"
|
||||
if init_file.exists():
|
||||
try:
|
||||
source = init_file.read_text()
|
||||
for line in source.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("WEB_DIRECTORY"):
|
||||
# Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web"
|
||||
_, _, value = stripped.partition("=")
|
||||
value = value.strip().strip("\"'")
|
||||
if value:
|
||||
web_dir_path = str((node_dir / value).resolve())
|
||||
if os.path.isdir(web_dir_path):
|
||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||
logger.debug(
|
||||
"][ Registered web dir for isolated %s: %s",
|
||||
extension_name,
|
||||
web_dir_path,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _get_extension_type(execution_model: str) -> type[Any]:
|
||||
if execution_model == "sealed_worker":
|
||||
return pyisolate.SealedNodeExtension
|
||||
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
return ComfyNodeExtension
|
||||
|
||||
|
||||
async def _stop_extension_safe(extension: Any, extension_name: str) -> None:
|
||||
try:
|
||||
stop_result = extension.stop()
|
||||
if inspect.isawaitable(stop_result):
|
||||
await stop_result
|
||||
except Exception:
|
||||
logger.debug("][ %s stop failed", extension_name, exc_info=True)
|
||||
|
||||
|
||||
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
|
||||
req, sep, marker = dep.partition(";")
|
||||
req = req.strip()
|
||||
marker_suffix = f";{marker}" if sep else ""
|
||||
|
||||
def _resolve_local_path(local_path: str) -> Path | None:
|
||||
for base in base_paths:
|
||||
candidate = (base / local_path).resolve()
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
if req.startswith("./") or req.startswith("../"):
|
||||
resolved = _resolve_local_path(req)
|
||||
if resolved is not None:
|
||||
return f"{resolved}{marker_suffix}"
|
||||
|
||||
if req.startswith("file://"):
|
||||
raw = req[len("file://") :]
|
||||
if raw.startswith("./") or raw.startswith("../"):
|
||||
resolved = _resolve_local_path(raw)
|
||||
if resolved is not None:
|
||||
return f"file://{resolved}{marker_suffix}"
|
||||
|
||||
return dep
|
||||
|
||||
|
||||
def _dependency_name_from_spec(dep: str) -> str | None:
|
||||
stripped = dep.strip()
|
||||
if not stripped or stripped == "-e" or stripped.startswith("-e "):
|
||||
return None
|
||||
if stripped.startswith(("/", "./", "../", "file://")):
|
||||
return None
|
||||
|
||||
try:
|
||||
return canonicalize_name(Requirement(stripped).name)
|
||||
except InvalidRequirement:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_cuda_wheels_config(
|
||||
tool_config: dict[str, object], dependencies: list[str]
|
||||
) -> dict[str, object] | None:
|
||||
raw_config = tool_config.get("cuda_wheels")
|
||||
if raw_config is None:
|
||||
return None
|
||||
if not isinstance(raw_config, dict):
|
||||
raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table")
|
||||
|
||||
index_url = raw_config.get("index_url")
|
||||
index_urls = raw_config.get("index_urls")
|
||||
if index_urls is not None:
|
||||
if not isinstance(index_urls, list) or not all(
|
||||
isinstance(u, str) and u.strip() for u in index_urls
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings"
|
||||
)
|
||||
elif not isinstance(index_url, str) or not index_url.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
|
||||
)
|
||||
|
||||
packages = raw_config.get("packages")
|
||||
if not isinstance(packages, list) or not all(
|
||||
isinstance(package_name, str) and package_name.strip()
|
||||
for package_name in packages
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings"
|
||||
)
|
||||
|
||||
declared_dependencies = {
|
||||
dependency_name
|
||||
for dep in dependencies
|
||||
if (dependency_name := _dependency_name_from_spec(dep)) is not None
|
||||
}
|
||||
normalized_packages = [canonicalize_name(package_name) for package_name in packages]
|
||||
missing = [
|
||||
package_name
|
||||
for package_name in normalized_packages
|
||||
if package_name not in declared_dependencies
|
||||
]
|
||||
if missing:
|
||||
missing_joined = ", ".join(sorted(missing))
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: "
|
||||
f"{missing_joined}"
|
||||
)
|
||||
|
||||
package_map = raw_config.get("package_map", {})
|
||||
if not isinstance(package_map, dict):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] must be a table"
|
||||
)
|
||||
|
||||
normalized_package_map: dict[str, str] = {}
|
||||
for dependency_name, index_package_name in package_map.items():
|
||||
if not isinstance(dependency_name, str) or not dependency_name.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings"
|
||||
)
|
||||
if not isinstance(index_package_name, str) or not index_package_name.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings"
|
||||
)
|
||||
canonical_dependency_name = canonicalize_name(dependency_name)
|
||||
if canonical_dependency_name not in normalized_packages:
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in "
|
||||
"[tool.comfy.isolation.cuda_wheels.packages]"
|
||||
)
|
||||
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
|
||||
|
||||
result: dict = {
|
||||
"packages": normalized_packages,
|
||||
"package_map": normalized_package_map,
|
||||
}
|
||||
if index_urls is not None:
|
||||
result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls]
|
||||
else:
|
||||
result["index_url"] = index_url.rstrip("/") + "/"
|
||||
return result
|
||||
|
||||
|
||||
def get_enforcement_policy() -> Dict[str, bool]:
|
||||
return {
|
||||
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
|
||||
"force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
|
||||
}
|
||||
|
||||
|
||||
class ExtensionLoadError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
|
||||
normalized_name = extension_name.replace("-", "_").replace(".", "_")
|
||||
if normalized_name not in sys.modules:
|
||||
dummy_module = types.ModuleType(normalized_name)
|
||||
dummy_module.__file__ = str(node_dir / "__init__.py")
|
||||
dummy_module.__path__ = [str(node_dir)]
|
||||
dummy_module.__package__ = normalized_name
|
||||
sys.modules[normalized_name] = dummy_module
|
||||
|
||||
|
||||
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
|
||||
for details in cached_data.values():
|
||||
if not isinstance(details, dict):
|
||||
return True
|
||||
if details.get("is_v3") and "schema_v1" not in details:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def load_isolated_node(
|
||||
node_dir: Path,
|
||||
manifest_path: Path,
|
||||
logger: logging.Logger,
|
||||
build_stub_class: Callable[[str, Dict[str, object], Any], type],
|
||||
venv_root: Path,
|
||||
extension_managers: List[ExtensionManager],
|
||||
) -> List[Tuple[str, str, type]]:
|
||||
try:
|
||||
with manifest_path.open("rb") as handle:
|
||||
manifest_data = tomllib.load(handle)
|
||||
except Exception as e:
|
||||
logger.warning(f"][ Failed to parse {manifest_path}: {e}")
|
||||
return []
|
||||
|
||||
# Parse [tool.comfy.isolation]
|
||||
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
|
||||
can_isolate = tool_config.get("can_isolate", False)
|
||||
share_torch = tool_config.get("share_torch", False)
|
||||
package_manager = tool_config.get("package_manager", "uv")
|
||||
is_conda = package_manager == "conda"
|
||||
execution_model = tool_config.get("execution_model")
|
||||
if execution_model is None:
|
||||
execution_model = "sealed_worker" if is_conda else "host-coupled"
|
||||
|
||||
if "sealed_host_ro_paths" in tool_config:
|
||||
raise ValueError(
|
||||
"Manifest field 'sealed_host_ro_paths' is not allowed. "
|
||||
"Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy."
|
||||
)
|
||||
|
||||
# Conda-specific manifest fields
|
||||
conda_channels: list[str] = (
|
||||
tool_config.get("conda_channels", []) if is_conda else []
|
||||
)
|
||||
conda_dependencies: list[str] = (
|
||||
tool_config.get("conda_dependencies", []) if is_conda else []
|
||||
)
|
||||
conda_platforms: list[str] = (
|
||||
tool_config.get("conda_platforms", []) if is_conda else []
|
||||
)
|
||||
conda_python: str = (
|
||||
tool_config.get("conda_python", "*") if is_conda else "*"
|
||||
)
|
||||
|
||||
# Parse [project] dependencies
|
||||
project_config = manifest_data.get("project", {})
|
||||
dependencies = project_config.get("dependencies", [])
|
||||
if not isinstance(dependencies, list):
|
||||
dependencies = []
|
||||
|
||||
# Get extension name (default to folder name if not in project.name)
|
||||
extension_name = project_config.get("name", node_dir.name)
|
||||
|
||||
# LOGIC: Isolation Decision
|
||||
policy = get_enforcement_policy()
|
||||
isolated = can_isolate or policy["force_isolated"]
|
||||
|
||||
if not isolated:
|
||||
return []
|
||||
|
||||
import folder_paths
|
||||
|
||||
base_paths = [Path(folder_paths.base_path), node_dir]
|
||||
dependencies = [
|
||||
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
|
||||
for dep in dependencies
|
||||
]
|
||||
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
|
||||
|
||||
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
|
||||
extension_type = _get_extension_type(execution_model)
|
||||
manager: ExtensionManager = pyisolate.ExtensionManager(
|
||||
extension_type, manager_config
|
||||
)
|
||||
extension_managers.append(manager)
|
||||
|
||||
host_policy = load_host_policy(Path(folder_paths.base_path))
|
||||
|
||||
sandbox_config = {}
|
||||
is_linux = platform.system() == "Linux"
|
||||
|
||||
if is_conda:
|
||||
share_torch = False
|
||||
share_cuda_ipc = False
|
||||
else:
|
||||
share_cuda_ipc = share_torch and is_linux
|
||||
|
||||
if is_linux and isolated:
|
||||
sandbox_config = {
|
||||
"network": host_policy["allow_network"],
|
||||
"writable_paths": host_policy["writable_paths"],
|
||||
"readonly_paths": host_policy["readonly_paths"],
|
||||
}
|
||||
|
||||
extension_config: dict = {
|
||||
"name": extension_name,
|
||||
"module_path": str(node_dir),
|
||||
"isolated": True,
|
||||
"dependencies": dependencies,
|
||||
"share_torch": share_torch,
|
||||
"share_cuda_ipc": share_cuda_ipc,
|
||||
"sandbox_mode": host_policy["sandbox_mode"],
|
||||
"sandbox": sandbox_config,
|
||||
}
|
||||
|
||||
share_torch_no_deps = tool_config.get("share_torch_no_deps", [])
|
||||
if share_torch_no_deps:
|
||||
if not isinstance(share_torch_no_deps, list) or not all(
|
||||
isinstance(dep, str) and dep.strip() for dep in share_torch_no_deps
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.share_torch_no_deps] must be a list of non-empty strings"
|
||||
)
|
||||
extension_config["share_torch_no_deps"] = share_torch_no_deps
|
||||
|
||||
_is_sealed = execution_model == "sealed_worker"
|
||||
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
|
||||
logger.info(
|
||||
"][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])",
|
||||
extension_name,
|
||||
"x" if share_torch else " ",
|
||||
"x" if _is_sealed else " ",
|
||||
"x" if _is_sandboxed else " ",
|
||||
)
|
||||
|
||||
if cuda_wheels is not None:
|
||||
extension_config["cuda_wheels"] = cuda_wheels
|
||||
|
||||
extra_index_urls = tool_config.get("extra_index_urls", [])
|
||||
if extra_index_urls:
|
||||
if not isinstance(extra_index_urls, list) or not all(
|
||||
isinstance(u, str) and u.strip() for u in extra_index_urls
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.extra_index_urls] must be a list of non-empty strings"
|
||||
)
|
||||
extension_config["extra_index_urls"] = extra_index_urls
|
||||
|
||||
# Conda-specific keys
|
||||
if is_conda:
|
||||
extension_config["package_manager"] = "conda"
|
||||
extension_config["conda_channels"] = conda_channels
|
||||
extension_config["conda_dependencies"] = conda_dependencies
|
||||
extension_config["conda_python"] = conda_python
|
||||
find_links = tool_config.get("find_links", [])
|
||||
if find_links:
|
||||
extension_config["find_links"] = find_links
|
||||
if conda_platforms:
|
||||
extension_config["conda_platforms"] = conda_platforms
|
||||
|
||||
if execution_model != "host-coupled":
|
||||
extension_config["execution_model"] = execution_model
|
||||
if execution_model == "sealed_worker":
|
||||
policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", [])
|
||||
if isinstance(policy_ro_paths, list) and policy_ro_paths:
|
||||
extension_config["sealed_host_ro_paths"] = list(policy_ro_paths)
|
||||
# Sealed workers keep the host RPC service inventory even when the
|
||||
# child resolves no API classes locally.
|
||||
|
||||
extension = manager.load_extension(extension_config)
|
||||
register_dummy_module(extension_name, node_dir)
|
||||
|
||||
# Register host-side event handlers via adapter
|
||||
from .adapter import ComfyUIAdapter
|
||||
ComfyUIAdapter.register_host_event_handlers(extension)
|
||||
|
||||
# Register web directory on the host — only when sandbox is disabled.
|
||||
# In sandbox mode, serving untrusted JS to the browser is not safe.
|
||||
if host_policy["sandbox_mode"] == "disabled":
|
||||
_register_web_directory(extension_name, node_dir)
|
||||
|
||||
# Register for proxied web serving — the child's web dir may have
|
||||
# content that doesn't exist on the host (e.g., pip-installed viewer
|
||||
# bundles). The WebDirectoryCache will lazily fetch via RPC.
|
||||
from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache
|
||||
|
||||
class ChildWebDirectoryProxy:
|
||||
def __init__(self, host_extension):
|
||||
self._host_extension = host_extension
|
||||
self._caller = None
|
||||
|
||||
def _get_caller(self):
|
||||
self._host_extension.proxy
|
||||
rpc = self._host_extension._extension.rpc
|
||||
caller = rpc.create_caller(WebDirectoryProxy, WebDirectoryProxy.get_remote_id())
|
||||
if self._caller is not caller:
|
||||
self._caller = caller
|
||||
return self._caller
|
||||
|
||||
def list_web_files(self, extension_name: str):
|
||||
from .proxies.base import run_sync_rpc_coro
|
||||
return run_sync_rpc_coro(self._get_caller().list_web_files(extension_name))
|
||||
|
||||
def get_web_file(self, extension_name: str, relative_path: str):
|
||||
from .proxies.base import run_sync_rpc_coro
|
||||
return run_sync_rpc_coro(
|
||||
self._get_caller().get_web_file(extension_name, relative_path)
|
||||
)
|
||||
|
||||
cache = get_web_directory_cache()
|
||||
cache.register_proxy(extension_name, ChildWebDirectoryProxy(extension))
|
||||
|
||||
# Try cache first (lazy spawn)
|
||||
if is_cache_valid(node_dir, manifest_path, venv_root):
|
||||
cached_data = load_from_cache(node_dir, venv_root)
|
||||
if cached_data:
|
||||
if _is_stale_node_cache(cached_data):
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
flushed = await extension.flush_pending_routes()
|
||||
logger.info("][ %s flushed %d routes", extension_name, flushed)
|
||||
except Exception as exc:
|
||||
logger.warning("][ %s route flush failed: %s", extension_name, exc)
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
for node_name, details in cached_data.items():
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append(
|
||||
(node_name, details.get("display_name", node_name), stub_cls)
|
||||
)
|
||||
return specs
|
||||
# Cache miss - spawn process and get metadata
|
||||
|
||||
try:
|
||||
remote_nodes: Dict[str, str] = await extension.list_nodes()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s metadata discovery failed, skipping isolated load: %s",
|
||||
extension_name,
|
||||
exc,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
if not remote_nodes:
|
||||
logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
cache_data: Dict[str, Dict] = {}
|
||||
|
||||
for node_name, display_name in remote_nodes.items():
|
||||
try:
|
||||
details = await extension.get_node_details(node_name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s failed to load metadata for %s, skipping node: %s",
|
||||
extension_name,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
details["display_name"] = display_name
|
||||
cache_data[node_name] = details
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append((node_name, display_name, stub_cls))
|
||||
|
||||
if not specs:
|
||||
logger.warning(
|
||||
"][ %s produced no usable nodes after metadata scan; skipping",
|
||||
extension_name,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
# Save metadata to cache for future runs
|
||||
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
|
||||
logger.debug(f"][ {extension_name} metadata cached")
|
||||
|
||||
# Re-check web directory AFTER child has populated it
|
||||
if host_policy["sandbox_mode"] == "disabled":
|
||||
_register_web_directory(extension_name, node_dir)
|
||||
|
||||
# Flush any routes the child buffered during module import — must happen
|
||||
# before router freeze and before we kill the child process.
|
||||
try:
|
||||
flushed = await extension.flush_pending_routes()
|
||||
logger.info("][ %s flushed %d routes", extension_name, flushed)
|
||||
except Exception as exc:
|
||||
logger.warning("][ %s route flush failed: %s", extension_name, exc)
|
||||
|
||||
# EJECT: Kill process after getting metadata (will respawn on first execution)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
|
||||
return specs
|
||||
|
||||
|
||||
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]
|
||||
942
comfy/isolation/extension_wrapper.py
Normal file
942
comfy/isolation/extension_wrapper.py
Normal file
@ -0,0 +1,942 @@
|
||||
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError as e:
|
||||
raise AttributeError(item) from e
|
||||
|
||||
def copy(self):
|
||||
return AttrDict(super().copy())
|
||||
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pyisolate import ExtensionBase
|
||||
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
V3_DISCOVERY_TIMEOUT = 30
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None:
|
||||
"""Run the web asset copy step that prestartup_script.py used to do.
|
||||
|
||||
If the module's web/ directory is empty and the module had a
|
||||
prestartup_script.py that copied assets from pip packages, this
|
||||
function replicates that work inside the child process.
|
||||
|
||||
Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if
|
||||
defined, otherwise falls back to detecting common asset packages.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Already populated — nothing to do
|
||||
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||
return
|
||||
|
||||
os.makedirs(web_dir_path, exist_ok=True)
|
||||
|
||||
# Try module-defined copy spec first (generic hook for any node pack)
|
||||
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
|
||||
if copy_spec is not None and callable(copy_spec):
|
||||
try:
|
||||
copy_spec(web_dir_path)
|
||||
logger.info(
|
||||
"%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"%s _PRESTARTUP_WEB_COPY failed for %s: %s",
|
||||
LOG_PREFIX, module_dir, e,
|
||||
)
|
||||
|
||||
# Fallback: detect comfy_3d_viewers and run copy_viewer()
|
||||
try:
|
||||
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
|
||||
viewers = list(VIEWER_FILES.keys())
|
||||
for viewer in viewers:
|
||||
try:
|
||||
copy_viewer(viewer, web_dir_path)
|
||||
except Exception:
|
||||
pass
|
||||
if any(os.scandir(web_dir_path)):
|
||||
logger.info(
|
||||
"%s Copied %d viewer types from comfy_3d_viewers to %s",
|
||||
LOG_PREFIX, len(viewers), web_dir_path,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fallback: detect comfy_dynamic_widgets
|
||||
try:
|
||||
from comfy_dynamic_widgets import get_js_path
|
||||
src = os.path.realpath(get_js_path())
|
||||
if os.path.exists(src):
|
||||
dst_dir = os.path.join(web_dir_path, "js")
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
dst = os.path.join(dst_dir, "dynamic_widgets.js")
|
||||
shutil.copy2(src, dst)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _read_extension_name(module_dir: str) -> str:
|
||||
"""Read extension name from pyproject.toml, falling back to directory name."""
|
||||
pyproject = os.path.join(module_dir, "pyproject.toml")
|
||||
if os.path.exists(pyproject):
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
try:
|
||||
with open(pyproject, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
name = data.get("project", {}).get("name")
|
||||
if name:
|
||||
return name
|
||||
except Exception:
|
||||
pass
|
||||
return os.path.basename(module_dir)
|
||||
|
||||
|
||||
def _flush_tensor_transport_state(marker: str) -> int:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return 0
|
||||
if not callable(flush_tensor_keeper):
|
||||
return 0
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||
)
|
||||
return flushed
|
||||
|
||||
|
||||
def _relieve_child_vram_pressure(marker: str) -> None:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if not hasattr(device, "type") or device.type == "cpu":
|
||||
return
|
||||
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=True)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
||||
|
||||
|
||||
def _sanitize_for_transport(value):
|
||||
primitives = (str, int, float, bool, type(None))
|
||||
if isinstance(value, primitives):
|
||||
return value
|
||||
|
||||
cls_name = value.__class__.__name__
|
||||
if cls_name == "FlexibleOptionalInputType":
|
||||
return {
|
||||
"__pyisolate_flexible_optional__": True,
|
||||
"type": _sanitize_for_transport(getattr(value, "type", "*")),
|
||||
}
|
||||
if cls_name == "AnyType":
|
||||
return {"__pyisolate_any_type__": True, "value": str(value)}
|
||||
if cls_name == "ByPassTypeTuple":
|
||||
return {
|
||||
"__pyisolate_bypass_tuple__": [
|
||||
_sanitize_for_transport(v) for v in tuple(value)
|
||||
]
|
||||
}
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {k: _sanitize_for_transport(v) for k, v in value.items()}
|
||||
if isinstance(value, tuple):
|
||||
return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
|
||||
if isinstance(value, list):
|
||||
return [_sanitize_for_transport(v) for v in value]
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
# Re-export RemoteObjectHandle from pyisolate for backward compatibility
|
||||
# The canonical definition is now in pyisolate._internal.remote_handle
|
||||
from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401
|
||||
|
||||
|
||||
class ComfyNodeExtension(ExtensionBase):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.node_classes: Dict[str, type] = {}
|
||||
self.display_names: Dict[str, str] = {}
|
||||
self.node_instances: Dict[str, Any] = {}
|
||||
self.remote_objects: Dict[str, Any] = {}
|
||||
self._route_handlers: Dict[str, Any] = {}
|
||||
self._module: Any = None
|
||||
self._metadata_ready = asyncio.Event()
|
||||
|
||||
async def on_module_loaded(self, module: Any) -> None:
|
||||
try:
|
||||
self._module = module
|
||||
|
||||
# Registries are initialized in host_hooks.py initialize_host_process()
|
||||
# They auto-register via ProxiedSingleton when instantiated
|
||||
# NO additional setup required here - if a registry is missing from host_hooks, it WILL fail
|
||||
|
||||
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
|
||||
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
|
||||
self._register_module_routes(module)
|
||||
|
||||
# Register web directory with WebDirectoryProxy (child-side)
|
||||
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
|
||||
if web_dir_attr is not None:
|
||||
module_dir = os.path.dirname(os.path.abspath(module.__file__))
|
||||
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
|
||||
ext_name = _read_extension_name(module_dir)
|
||||
|
||||
# If web dir is empty, run the copy step that prestartup_script.py did
|
||||
_run_prestartup_web_copy(module, module_dir, web_dir_path)
|
||||
|
||||
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
|
||||
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
|
||||
|
||||
try:
|
||||
from comfy_api.latest import ComfyExtension
|
||||
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if not (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, ComfyExtension)
|
||||
and obj is not ComfyExtension
|
||||
):
|
||||
continue
|
||||
if not obj.__module__.startswith(module.__name__):
|
||||
continue
|
||||
try:
|
||||
ext_instance = obj()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in on_load()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
try:
|
||||
v3_nodes = await asyncio.wait_for(
|
||||
ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in get_node_list()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
for node_cls in v3_nodes:
|
||||
if hasattr(node_cls, "GET_SCHEMA"):
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
self.node_classes[schema.node_id] = node_cls
|
||||
if schema.display_name:
|
||||
self.display_names[schema.node_id] = schema.display_name
|
||||
except Exception as e:
|
||||
logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
module_name = getattr(module, "__name__", "isolated_nodes")
|
||||
for node_cls in self.node_classes.values():
|
||||
if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
|
||||
node_cls.__module__ = module_name
|
||||
|
||||
self.node_instances = {}
|
||||
finally:
|
||||
self._metadata_ready.set()
|
||||
|
||||
def _register_module_routes(self, module: Any) -> None:
|
||||
"""Bridge legacy module-level ROUTES declarations into isolated routing."""
|
||||
routes = getattr(module, "ROUTES", None) or []
|
||||
if not routes:
|
||||
return
|
||||
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
|
||||
|
||||
prompt_server = PromptServerStub()
|
||||
route_table = getattr(prompt_server, "routes", None)
|
||||
if route_table is None:
|
||||
logger.warning("%s Route registration unavailable for %s", LOG_PREFIX, module)
|
||||
return
|
||||
|
||||
for route_spec in routes:
|
||||
if not isinstance(route_spec, dict):
|
||||
logger.warning("%s Ignoring non-dict ROUTES entry: %r", LOG_PREFIX, route_spec)
|
||||
continue
|
||||
|
||||
method = str(route_spec.get("method", "")).strip().upper()
|
||||
path = str(route_spec.get("path", "")).strip()
|
||||
handler_ref = route_spec.get("handler")
|
||||
if not method or not path:
|
||||
logger.warning("%s Ignoring incomplete route spec: %r", LOG_PREFIX, route_spec)
|
||||
continue
|
||||
|
||||
if isinstance(handler_ref, str):
|
||||
handler = getattr(module, handler_ref, None)
|
||||
else:
|
||||
handler = handler_ref
|
||||
if not callable(handler):
|
||||
logger.warning(
|
||||
"%s Ignoring route with missing handler %r for %s %s",
|
||||
LOG_PREFIX,
|
||||
handler_ref,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
continue
|
||||
|
||||
decorator = getattr(route_table, method.lower(), None)
|
||||
if not callable(decorator):
|
||||
logger.warning("%s Unsupported route method %s for %s", LOG_PREFIX, method, path)
|
||||
continue
|
||||
|
||||
decorator(path)(handler)
|
||||
self._route_handlers[f"{method} {path}"] = handler
|
||||
logger.info("%s buffered legacy route %s %s", LOG_PREFIX, method, path)
|
||||
|
||||
async def list_nodes(self) -> Dict[str, str]:
|
||||
await asyncio.wait_for(
|
||||
self._metadata_ready.wait(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
return {name: self.display_names.get(name, name) for name in self.node_classes}
|
||||
|
||||
async def get_node_info(self, node_name: str) -> Dict[str, Any]:
|
||||
return await self.get_node_details(node_name)
|
||||
|
||||
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
|
||||
await asyncio.wait_for(
|
||||
self._metadata_ready.wait(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
node_cls = self._get_node_class(node_name)
|
||||
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
|
||||
|
||||
input_types_raw = (
|
||||
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
|
||||
)
|
||||
output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
|
||||
if output_is_list is not None:
|
||||
output_is_list = tuple(bool(x) for x in output_is_list)
|
||||
|
||||
details: Dict[str, Any] = {
|
||||
"input_types": _sanitize_for_transport(input_types_raw),
|
||||
"return_types": tuple(
|
||||
str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
|
||||
),
|
||||
"return_names": getattr(node_cls, "RETURN_NAMES", None),
|
||||
"function": str(getattr(node_cls, "FUNCTION", "execute")),
|
||||
"category": str(getattr(node_cls, "CATEGORY", "")),
|
||||
"output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
|
||||
"output_is_list": output_is_list,
|
||||
"is_v3": is_v3,
|
||||
}
|
||||
|
||||
if is_v3:
|
||||
try:
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
schema_v1 = asdict(schema.get_v1_info(node_cls))
|
||||
try:
|
||||
schema_v3 = asdict(schema.get_v3_info(node_cls))
|
||||
except (AttributeError, TypeError):
|
||||
schema_v3 = self._build_schema_v3_fallback(schema)
|
||||
details.update(
|
||||
{
|
||||
"schema_v1": schema_v1,
|
||||
"schema_v3": schema_v3,
|
||||
"hidden": [h.value for h in (schema.hidden or [])],
|
||||
"description": getattr(schema, "description", ""),
|
||||
"deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
|
||||
"experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
|
||||
"api_node": bool(getattr(node_cls, "API_NODE", False)),
|
||||
"input_is_list": bool(
|
||||
getattr(node_cls, "INPUT_IS_LIST", False)
|
||||
),
|
||||
"not_idempotent": bool(
|
||||
getattr(node_cls, "NOT_IDEMPOTENT", False)
|
||||
),
|
||||
"accept_all_inputs": bool(
|
||||
getattr(node_cls, "ACCEPT_ALL_INPUTS", False)
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"%s V3 schema serialization failed for %s: %s",
|
||||
LOG_PREFIX,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
return details
|
||||
|
||||
def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
|
||||
input_dict: Dict[str, Any] = {}
|
||||
output_dict: Dict[str, Any] = {}
|
||||
hidden_list: List[str] = []
|
||||
|
||||
if getattr(schema, "inputs", None):
|
||||
for inp in schema.inputs:
|
||||
self._add_schema_io_v3(inp, input_dict)
|
||||
if getattr(schema, "outputs", None):
|
||||
for out in schema.outputs:
|
||||
self._add_schema_io_v3(out, output_dict)
|
||||
if getattr(schema, "hidden", None):
|
||||
for h in schema.hidden:
|
||||
hidden_list.append(getattr(h, "value", str(h)))
|
||||
|
||||
return {
|
||||
"input": input_dict,
|
||||
"output": output_dict,
|
||||
"hidden": hidden_list,
|
||||
"name": getattr(schema, "node_id", None),
|
||||
"display_name": getattr(schema, "display_name", None),
|
||||
"description": getattr(schema, "description", None),
|
||||
"category": getattr(schema, "category", None),
|
||||
"output_node": getattr(schema, "is_output_node", False),
|
||||
"deprecated": getattr(schema, "is_deprecated", False),
|
||||
"experimental": getattr(schema, "is_experimental", False),
|
||||
"api_node": getattr(schema, "is_api_node", False),
|
||||
}
|
||||
|
||||
def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
|
||||
io_id = getattr(io_obj, "id", None)
|
||||
if io_id is None:
|
||||
return
|
||||
|
||||
io_type_fn = getattr(io_obj, "get_io_type", None)
|
||||
io_type = (
|
||||
io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
|
||||
)
|
||||
|
||||
as_dict_fn = getattr(io_obj, "as_dict", None)
|
||||
payload = as_dict_fn() if callable(as_dict_fn) else {}
|
||||
|
||||
target[str(io_id)] = (io_type, payload)
|
||||
|
||||
async def get_input_types(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
if hasattr(node_cls, "INPUT_TYPES"):
|
||||
return node_cls.INPUT_TYPES()
|
||||
return {}
|
||||
|
||||
async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
|
||||
logger.debug(
|
||||
"%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
len(inputs),
|
||||
)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
_relieve_child_vram_pressure("EXT:pre_execute")
|
||||
|
||||
resolved_inputs = self._resolve_remote_objects(inputs)
|
||||
|
||||
instance = self._get_node_instance(node_name)
|
||||
node_cls = self._get_node_class(node_name)
|
||||
|
||||
# V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
|
||||
# Hidden params come through RPC as string keys like "Hidden.prompt"
|
||||
from comfy_api.latest._io import Hidden, HiddenHolder
|
||||
|
||||
# Map string representations back to Hidden enum keys
|
||||
hidden_string_map = {
|
||||
"Hidden.unique_id": Hidden.unique_id,
|
||||
"Hidden.prompt": Hidden.prompt,
|
||||
"Hidden.extra_pnginfo": Hidden.extra_pnginfo,
|
||||
"Hidden.dynprompt": Hidden.dynprompt,
|
||||
"Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
|
||||
"Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
|
||||
# Uppercase enum VALUE forms — V3 execution engine passes these
|
||||
"UNIQUE_ID": Hidden.unique_id,
|
||||
"PROMPT": Hidden.prompt,
|
||||
"EXTRA_PNGINFO": Hidden.extra_pnginfo,
|
||||
"DYNPROMPT": Hidden.dynprompt,
|
||||
"AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org,
|
||||
"API_KEY_COMFY_ORG": Hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Find and extract hidden parameters (both enum and string form)
|
||||
hidden_found = {}
|
||||
keys_to_remove = []
|
||||
|
||||
for key in list(resolved_inputs.keys()):
|
||||
# Check string form first (from RPC serialization)
|
||||
if key in hidden_string_map:
|
||||
hidden_found[hidden_string_map[key]] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
# Also check enum form (direct calls)
|
||||
elif isinstance(key, Hidden):
|
||||
hidden_found[key] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
|
||||
# Remove hidden params from kwargs
|
||||
for key in keys_to_remove:
|
||||
resolved_inputs.pop(key)
|
||||
|
||||
# Set hidden on node class if any hidden params found
|
||||
if hidden_found:
|
||||
if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
|
||||
node_cls.hidden = HiddenHolder.from_dict(hidden_found)
|
||||
else:
|
||||
# Update existing hidden holder
|
||||
for key, value in hidden_found.items():
|
||||
setattr(node_cls.hidden, key.value.lower(), value)
|
||||
|
||||
# INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this
|
||||
# flag is set. The isolation RPC delivers unwrapped values, so we must
|
||||
# wrap each input in a single-element list to match the contract.
|
||||
if getattr(node_cls, "INPUT_IS_LIST", False):
|
||||
resolved_inputs = {k: [v] for k, v in resolved_inputs.items()}
|
||||
|
||||
function_name = getattr(node_cls, "FUNCTION", "execute")
|
||||
if not hasattr(instance, function_name):
|
||||
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
|
||||
|
||||
handler = getattr(instance, function_name)
|
||||
|
||||
try:
|
||||
import torch
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
with torch.inference_mode():
|
||||
result = await handler(**resolved_inputs)
|
||||
else:
|
||||
import functools
|
||||
|
||||
def _run_with_inference_mode(**kwargs):
|
||||
with torch.inference_mode():
|
||||
return handler(**kwargs)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, functools.partial(_run_with_inference_mode, **resolved_inputs)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s ISO:child_execute_error ext=%s node=%s",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
)
|
||||
raise
|
||||
|
||||
if type(result).__name__ == "NodeOutput":
|
||||
node_output_dict = {
|
||||
"__node_output__": True,
|
||||
"args": self._wrap_unpicklable_objects(result.args),
|
||||
}
|
||||
if result.ui is not None:
|
||||
node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui)
|
||||
if getattr(result, "expand", None) is not None:
|
||||
node_output_dict["expand"] = result.expand
|
||||
if getattr(result, "block_execution", None) is not None:
|
||||
node_output_dict["block_execution"] = result.block_execution
|
||||
return node_output_dict
|
||||
if self._is_comfy_protocol_return(result):
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
async def flush_pending_routes(self) -> int:
|
||||
"""Flush buffered route registrations to host via RPC. Called by host after node discovery."""
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
|
||||
return await PromptServerStub.flush_child_routes()
|
||||
|
||||
async def flush_transport_state(self) -> int:
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
return 0
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
|
||||
)
|
||||
flushed = _flush_tensor_transport_state("EXT:workflow_end")
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import (
|
||||
ModelPatcherRegistry,
|
||||
)
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
removed = registry.sweep_pending_cleanup()
|
||||
if removed > 0:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_done ext=%s flushed=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
flushed,
|
||||
)
|
||||
return flushed
|
||||
|
||||
async def get_remote_object(self, object_id: str) -> Any:
|
||||
"""Retrieve a remote object by ID for host-side deserialization."""
|
||||
if object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {object_id} not found")
|
||||
|
||||
return self.remote_objects[object_id]
|
||||
|
||||
def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle:
|
||||
object_id = str(uuid.uuid4())
|
||||
self.remote_objects[object_id] = obj
|
||||
return RemoteObjectHandle(object_id, type(obj).__name__)
|
||||
|
||||
async def call_remote_object_method(
|
||||
self,
|
||||
object_id: str,
|
||||
method_name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Invoke a method or attribute-backed accessor on a child-owned object."""
|
||||
obj = await self.get_remote_object(object_id)
|
||||
|
||||
if method_name == "get_patcher_attr":
|
||||
return getattr(obj, args[0])
|
||||
if method_name == "get_model_options":
|
||||
return getattr(obj, "model_options")
|
||||
if method_name == "set_model_options":
|
||||
setattr(obj, "model_options", args[0])
|
||||
return None
|
||||
if method_name == "get_object_patches":
|
||||
return getattr(obj, "object_patches")
|
||||
if method_name == "get_patches":
|
||||
return getattr(obj, "patches")
|
||||
if method_name == "get_wrappers":
|
||||
return getattr(obj, "wrappers")
|
||||
if method_name == "get_callbacks":
|
||||
return getattr(obj, "callbacks")
|
||||
if method_name == "get_load_device":
|
||||
return getattr(obj, "load_device")
|
||||
if method_name == "get_offload_device":
|
||||
return getattr(obj, "offload_device")
|
||||
if method_name == "get_hook_mode":
|
||||
return getattr(obj, "hook_mode")
|
||||
if method_name == "get_parent":
|
||||
parent = getattr(obj, "parent", None)
|
||||
if parent is None:
|
||||
return None
|
||||
return self._store_remote_object_handle(parent)
|
||||
if method_name == "get_inner_model_attr":
|
||||
attr_name = args[0]
|
||||
if hasattr(obj.model, attr_name):
|
||||
return getattr(obj.model, attr_name)
|
||||
if hasattr(obj, attr_name):
|
||||
return getattr(obj, attr_name)
|
||||
return None
|
||||
if method_name == "inner_model_apply_model":
|
||||
return obj.model.apply_model(*args[0], **args[1])
|
||||
if method_name == "inner_model_extra_conds_shapes":
|
||||
return obj.model.extra_conds_shapes(*args[0], **args[1])
|
||||
if method_name == "inner_model_extra_conds":
|
||||
return obj.model.extra_conds(*args[0], **args[1])
|
||||
if method_name == "inner_model_memory_required":
|
||||
return obj.model.memory_required(*args[0], **args[1])
|
||||
if method_name == "process_latent_in":
|
||||
return obj.model.process_latent_in(*args[0], **args[1])
|
||||
if method_name == "process_latent_out":
|
||||
return obj.model.process_latent_out(*args[0], **args[1])
|
||||
if method_name == "scale_latent_inpaint":
|
||||
return obj.model.scale_latent_inpaint(*args[0], **args[1])
|
||||
if method_name.startswith("get_"):
|
||||
attr_name = method_name[4:]
|
||||
if hasattr(obj, attr_name):
|
||||
return getattr(obj, attr_name)
|
||||
|
||||
target = getattr(obj, method_name)
|
||||
if callable(target):
|
||||
result = target(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if type(result).__name__ == "ModelPatcher":
|
||||
return self._store_remote_object_handle(result)
|
||||
return result
|
||||
if args or kwargs:
|
||||
raise TypeError(f"{method_name} is not callable on remote object {object_id}")
|
||||
return target
|
||||
|
||||
def _wrap_unpicklable_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, (str, int, float, bool, type(None))):
|
||||
return data
|
||||
if isinstance(data, torch.Tensor):
|
||||
tensor = data.detach() if data.requires_grad else data
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
|
||||
return tensor.cpu()
|
||||
return tensor
|
||||
|
||||
# Special-case clip vision outputs: preserve attribute access by packing fields
|
||||
if hasattr(data, "penultimate_hidden_states") or hasattr(
|
||||
data, "last_hidden_state"
|
||||
):
|
||||
fields = {}
|
||||
for attr in (
|
||||
"penultimate_hidden_states",
|
||||
"last_hidden_state",
|
||||
"image_embeds",
|
||||
"text_embeds",
|
||||
):
|
||||
if hasattr(data, attr):
|
||||
try:
|
||||
fields[attr] = self._wrap_unpicklable_objects(
|
||||
getattr(data, attr)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if fields:
|
||||
return {"__pyisolate_attribute_container__": True, "data": fields}
|
||||
|
||||
# Avoid converting arbitrary objects with stateful methods (models, etc.)
|
||||
# They will be handled via RemoteObjectHandle below.
|
||||
|
||||
type_name = type(data).__name__
|
||||
if type_name == "ModelPatcherProxy":
|
||||
return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
|
||||
if type_name == "CLIPProxy":
|
||||
return {"__type__": "CLIPRef", "clip_id": data._instance_id}
|
||||
if type_name == "VAEProxy":
|
||||
return {"__type__": "VAERef", "vae_id": data._instance_id}
|
||||
if type_name == "ModelSamplingProxy":
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
wrapped = [self._wrap_unpicklable_objects(item) for item in data]
|
||||
return tuple(wrapped) if isinstance(data, tuple) else wrapped
|
||||
if isinstance(data, dict):
|
||||
converted_dict = {
|
||||
k: self._wrap_unpicklable_objects(v) for k, v in data.items()
|
||||
}
|
||||
return {"__pyisolate_attrdict__": True, "data": converted_dict}
|
||||
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
|
||||
registry = SerializerRegistry.get_instance()
|
||||
if registry.is_data_type(type_name):
|
||||
serializer = registry.get_serializer(type_name)
|
||||
if serializer:
|
||||
return serializer(data)
|
||||
|
||||
return self._store_remote_object_handle(data)
|
||||
|
||||
def _resolve_remote_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, RemoteObjectHandle):
|
||||
if data.object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {data.object_id} not found")
|
||||
return self.remote_objects[data.object_id]
|
||||
|
||||
if isinstance(data, dict):
|
||||
ref_type = data.get("__type__")
|
||||
if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
if ref_type == "ModelSamplingRef":
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
return {k: self._resolve_remote_objects(v) for k, v in data.items()}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
resolved = [self._resolve_remote_objects(item) for item in data]
|
||||
return tuple(resolved) if isinstance(data, tuple) else resolved
|
||||
return data
|
||||
|
||||
def _get_node_class(self, node_name: str) -> type:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
return self.node_classes[node_name]
|
||||
|
||||
def _get_node_instance(self, node_name: str) -> Any:
|
||||
if node_name not in self.node_instances:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
self.node_instances[node_name] = self.node_classes[node_name]()
|
||||
return self.node_instances[node_name]
|
||||
|
||||
async def before_module_loaded(self) -> None:
|
||||
try:
|
||||
from comfy.isolation import initialize_proxies
|
||||
|
||||
initialize_proxies()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e
|
||||
)
|
||||
|
||||
await super().before_module_loaded()
|
||||
try:
|
||||
from comfy_api.latest import ComfyAPI_latest
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
|
||||
ComfyAPI_latest.Execution = ProgressProxy
|
||||
# ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision
|
||||
# fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision
|
||||
# latest_ui.folder_paths = fp_proxy
|
||||
# latest_resources.folder_paths = fp_proxy
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def call_route_handler(
|
||||
self,
|
||||
handler_module: str,
|
||||
handler_func: str,
|
||||
request_data: Dict[str, Any],
|
||||
) -> Any:
|
||||
cache_key = f"{handler_module}.{handler_func}"
|
||||
if cache_key not in self._route_handlers:
|
||||
if self._module is not None and hasattr(self._module, "__file__"):
|
||||
node_dir = os.path.dirname(self._module.__file__)
|
||||
if node_dir not in sys.path:
|
||||
sys.path.insert(0, node_dir)
|
||||
try:
|
||||
module = importlib.import_module(handler_module)
|
||||
self._route_handlers[cache_key] = getattr(module, handler_func)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(f"Route handler not found: {cache_key}") from e
|
||||
|
||||
handler = self._route_handlers[cache_key]
|
||||
mock_request = MockRequest(request_data)
|
||||
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
result = await handler(mock_request)
|
||||
else:
|
||||
result = handler(mock_request)
|
||||
return self._serialize_response(result)
|
||||
|
||||
def _is_comfy_protocol_return(self, result: Any) -> bool:
|
||||
"""
|
||||
Check if the result matches the ComfyUI 'Protocol Return' schema.
|
||||
|
||||
A Protocol Return is a dictionary containing specific reserved keys that
|
||||
ComfyUI's execution engine interprets as instructions (UI updates,
|
||||
Workflow expansion, etc.) rather than purely data outputs.
|
||||
|
||||
Schema:
|
||||
- Must be a dict
|
||||
- Must contain at least one of: 'ui', 'result', 'expand'
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return False
|
||||
return any(key in result for key in ("ui", "result", "expand"))
|
||||
|
||||
def _serialize_response(self, response: Any) -> Dict[str, Any]:
|
||||
if response is None:
|
||||
return {"type": "text", "body": "", "status": 204}
|
||||
if isinstance(response, dict):
|
||||
return {"type": "json", "body": response, "status": 200}
|
||||
if isinstance(response, str):
|
||||
return {"type": "text", "body": response, "status": 200}
|
||||
if hasattr(response, "text") and hasattr(response, "status"):
|
||||
return {
|
||||
"type": "text",
|
||||
"body": response.text
|
||||
if hasattr(response, "text")
|
||||
else str(response.body),
|
||||
"status": response.status,
|
||||
"headers": dict(response.headers)
|
||||
if hasattr(response, "headers")
|
||||
else {},
|
||||
}
|
||||
if hasattr(response, "body") and hasattr(response, "status"):
|
||||
body = response.body
|
||||
if isinstance(body, bytes):
|
||||
try:
|
||||
return {
|
||||
"type": "text",
|
||||
"body": body.decode("utf-8"),
|
||||
"status": response.status,
|
||||
}
|
||||
except UnicodeDecodeError:
|
||||
return {
|
||||
"type": "binary",
|
||||
"body": body.hex(),
|
||||
"status": response.status,
|
||||
}
|
||||
return {"type": "json", "body": body, "status": response.status}
|
||||
return {"type": "text", "body": str(response), "status": 200}
|
||||
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.method = data.get("method", "GET")
|
||||
self.path = data.get("path", "/")
|
||||
self.query = data.get("query", {})
|
||||
self._body = data.get("body", {})
|
||||
self._text = data.get("text", "")
|
||||
self.headers = data.get("headers", {})
|
||||
self.content_type = data.get(
|
||||
"content_type", self.headers.get("Content-Type", "application/json")
|
||||
)
|
||||
self.match_info = data.get("match_info", {})
|
||||
|
||||
async def json(self) -> Any:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
if isinstance(self._body, str):
|
||||
return json.loads(self._body)
|
||||
return {}
|
||||
|
||||
async def post(self) -> Dict[str, Any]:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
return {}
|
||||
|
||||
async def text(self) -> str:
|
||||
if self._text:
|
||||
return self._text
|
||||
if isinstance(self._body, str):
|
||||
return self._body
|
||||
if isinstance(self._body, dict):
|
||||
return json.dumps(self._body)
|
||||
return ""
|
||||
|
||||
async def read(self) -> bytes:
|
||||
return (await self.text()).encode("utf-8")
|
||||
161
execution.py
161
execution.py
@ -1,7 +1,9 @@
|
||||
import copy
|
||||
import gc
|
||||
import heapq
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@ -42,6 +44,8 @@ from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_re
|
||||
from comfy_api.latest import io, _io
|
||||
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
||||
|
||||
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = False
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
@ -262,20 +266,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
pre_execute_cb(index)
|
||||
# V3
|
||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||
# if is just a class, then assign no state, just create clone
|
||||
if is_class(obj):
|
||||
type_obj = obj
|
||||
obj.VALIDATE_CLASS()
|
||||
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
|
||||
# otherwise, use class instance to populate/reuse some fields
|
||||
# Check for isolated node - skip validation and class cloning
|
||||
if hasattr(obj, "_pyisolate_extension"):
|
||||
# Isolated Node: The stub is just a proxy; real validation happens in child process
|
||||
if v3_data is not None:
|
||||
inputs = _io.build_nested_inputs(inputs, v3_data)
|
||||
# Inject hidden inputs so they're available in the isolated child process
|
||||
inputs.update(v3_data.get("hidden_inputs", {}))
|
||||
f = getattr(obj, func)
|
||||
# Standard V3 Node (Existing Logic)
|
||||
|
||||
else:
|
||||
type_obj = type(obj)
|
||||
type_obj.VALIDATE_CLASS()
|
||||
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
|
||||
f = make_locked_method_func(type_obj, func, class_clone)
|
||||
# in case of dynamic inputs, restructure inputs to expected nested dict
|
||||
if v3_data is not None:
|
||||
inputs = _io.build_nested_inputs(inputs, v3_data)
|
||||
# if is just a class, then assign no resources or state, just create clone
|
||||
if is_class(obj):
|
||||
type_obj = obj
|
||||
obj.VALIDATE_CLASS()
|
||||
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
|
||||
# otherwise, use class instance to populate/reuse some fields
|
||||
else:
|
||||
type_obj = type(obj)
|
||||
type_obj.VALIDATE_CLASS()
|
||||
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
|
||||
f = make_locked_method_func(type_obj, func, class_clone)
|
||||
# in case of dynamic inputs, restructure inputs to expected nested dict
|
||||
if v3_data is not None:
|
||||
inputs = _io.build_nested_inputs(inputs, v3_data)
|
||||
# V1
|
||||
else:
|
||||
f = getattr(obj, func)
|
||||
@ -537,7 +552,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.control.analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
vbar_lib = getattr(comfy_aimdo.model_vbar, "lib", None)
|
||||
if vbar_lib is not None:
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
else:
|
||||
global _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED
|
||||
if not _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED:
|
||||
logging.warning(
|
||||
"DynamicVRAM backend unavailable for watermark reset; "
|
||||
"skipping vbar reset for this process."
|
||||
)
|
||||
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = True
|
||||
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
@ -546,8 +571,29 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
unblock()
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
|
||||
# Keep isolation node execution deterministic by default, but allow
|
||||
# opt-out for diagnostics.
|
||||
isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes")
|
||||
if args.use_process_isolation and isolation_sequential:
|
||||
await await_completion()
|
||||
results = []
|
||||
for r in pending_async_nodes[unique_id]:
|
||||
if isinstance(r, asyncio.Task):
|
||||
try:
|
||||
results.append(r.result())
|
||||
except Exception as ex:
|
||||
del pending_async_nodes[unique_id]
|
||||
raise ex
|
||||
else:
|
||||
results.append(r)
|
||||
del pending_async_nodes[unique_id]
|
||||
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
|
||||
has_pending_tasks = False
|
||||
|
||||
else:
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
ui_outputs[unique_id] = {
|
||||
"meta": {
|
||||
@ -657,6 +703,46 @@ class PromptExecutor:
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
|
||||
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
|
||||
if not args.use_process_isolation:
|
||||
return
|
||||
try:
|
||||
from comfy.isolation import notify_execution_graph
|
||||
await notify_execution_graph(class_types, caches=self.caches.all)
|
||||
except Exception:
|
||||
if fail_loud:
|
||||
raise
|
||||
logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
|
||||
|
||||
async def _flush_running_extensions_transport_state_safe(self) -> None:
|
||||
if not args.use_process_isolation:
|
||||
return
|
||||
try:
|
||||
from comfy.isolation import flush_running_extensions_transport_state
|
||||
await flush_running_extensions_transport_state()
|
||||
except Exception:
|
||||
logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
|
||||
|
||||
async def _wait_model_patcher_quiescence_safe(
|
||||
self,
|
||||
*,
|
||||
fail_loud: bool = False,
|
||||
timeout_ms: int = 120000,
|
||||
marker: str = "EX:wait_model_patcher_idle",
|
||||
) -> None:
|
||||
if not args.use_process_isolation:
|
||||
return
|
||||
try:
|
||||
from comfy.isolation import wait_for_model_patcher_quiescence
|
||||
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker
|
||||
)
|
||||
except Exception:
|
||||
if fail_loud:
|
||||
raise
|
||||
logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True)
|
||||
|
||||
def add_message(self, event, data: dict, broadcast: bool):
|
||||
data = {
|
||||
**data,
|
||||
@ -711,6 +797,18 @@ class PromptExecutor:
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
if args.use_process_isolation:
|
||||
# Update RPC event loops for all isolated extensions.
|
||||
# This is critical for serial workflow execution - each asyncio.run() creates
|
||||
# a new event loop, and RPC instances must be updated to use it.
|
||||
try:
|
||||
from comfy.isolation import update_rpc_event_loops
|
||||
update_rpc_event_loops()
|
||||
except ImportError:
|
||||
pass # Isolation not available
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
|
||||
nodes.interrupt_processing(False)
|
||||
@ -723,6 +821,25 @@ class PromptExecutor:
|
||||
self.status_messages = []
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
if args.use_process_isolation:
|
||||
try:
|
||||
# Boundary cleanup runs at the start of the next workflow in
|
||||
# isolation mode, matching non-isolated "next prompt" timing.
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||
await self._wait_model_patcher_quiescence_safe(
|
||||
fail_loud=False,
|
||||
timeout_ms=120000,
|
||||
marker="EX:boundary_cleanup_wait_idle",
|
||||
)
|
||||
await self._flush_running_extensions_transport_state_safe()
|
||||
comfy.model_management.unload_all_models()
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
comfy.model_management.cleanup_models()
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
except Exception:
|
||||
logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True)
|
||||
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
||||
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
|
||||
@ -760,6 +877,18 @@ class PromptExecutor:
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
if args.use_process_isolation:
|
||||
pending_class_types = set()
|
||||
for node_id in execution_list.pendingNodes.keys():
|
||||
class_type = dynamic_prompt.get_node(node_id)["class_type"]
|
||||
pending_class_types.add(class_type)
|
||||
await self._wait_model_patcher_quiescence_safe(
|
||||
fail_loud=True,
|
||||
timeout_ms=120000,
|
||||
marker="EX:notify_graph_wait_idle",
|
||||
)
|
||||
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
|
||||
120
main.py
120
main.py
@ -1,7 +1,27 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
IS_PYISOLATE_CHILD = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
if __name__ == "__main__" and IS_PYISOLATE_CHILD:
|
||||
del os.environ["PYISOLATE_CHILD"]
|
||||
IS_PYISOLATE_CHILD = False
|
||||
|
||||
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
if CURRENT_DIR not in sys.path:
|
||||
sys.path.insert(0, CURRENT_DIR)
|
||||
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
python_scripts_dir = os.path.dirname(os.path.realpath(sys.executable))
|
||||
path_entries = os.environ.get("PATH", "").split(os.pathsep)
|
||||
if python_scripts_dir and python_scripts_dir not in path_entries:
|
||||
os.environ["PATH"] = os.pathsep.join([python_scripts_dir, *path_entries])
|
||||
|
||||
IS_PRIMARY_PROCESS = (not IS_PYISOLATE_CHILD) and __name__ == "__main__"
|
||||
|
||||
import comfy.options
|
||||
comfy.options.enable_args_parsing()
|
||||
|
||||
import os
|
||||
import importlib.util
|
||||
import shutil
|
||||
import importlib.metadata
|
||||
@ -9,26 +29,56 @@ import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.services import register_output_files
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
import utils.extra_config # noqa: F401
|
||||
from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
from app.database.db import init_db, dependencies_available
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
if not comfy_aimdo.control.init():
|
||||
logging.warning(
|
||||
"DynamicVRAM requested, but comfy-aimdo failed to initialize early. "
|
||||
"Will fall back to legacy model loading if device init fails."
|
||||
)
|
||||
|
||||
if args.use_process_isolation:
|
||||
from comfy.isolation import initialize_proxies
|
||||
initialize_proxies()
|
||||
|
||||
# Explicitly register the ComfyUI adapter for pyisolate (v1.0 architecture)
|
||||
try:
|
||||
import pyisolate
|
||||
from comfy.isolation.adapter import ComfyUIAdapter
|
||||
pyisolate.register_adapter(ComfyUIAdapter())
|
||||
logging.info("PyIsolate adapter registered: comfyui")
|
||||
except ImportError:
|
||||
logging.warning("PyIsolate not installed or version too old for explicit registration")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to register PyIsolate adapter: {e}")
|
||||
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ:
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:native'
|
||||
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
|
||||
if IS_PRIMARY_PROCESS:
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||
|
||||
import comfy_aimdo.control
|
||||
@ -93,14 +143,15 @@ if args.enable_manager:
|
||||
|
||||
|
||||
def apply_custom_paths():
|
||||
from utils import extra_config # Deferred import - spawn re-runs main.py
|
||||
# extra model paths
|
||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||
if os.path.isfile(extra_model_paths_config_path):
|
||||
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||
extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||
|
||||
if args.extra_model_paths_config:
|
||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||
utils.extra_config.load_extra_path_config(config_path)
|
||||
extra_config.load_extra_path_config(config_path)
|
||||
|
||||
# --output-directory, --input-directory, --user-directory
|
||||
if args.output_directory:
|
||||
@ -173,15 +224,17 @@ def execute_prestartup_script():
|
||||
else:
|
||||
import_message = " (PRESTARTUP FAILED)"
|
||||
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
|
||||
logging.info("")
|
||||
logging.info("")
|
||||
|
||||
apply_custom_paths()
|
||||
init_mime_types()
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
apply_custom_paths()
|
||||
init_mime_types()
|
||||
|
||||
if args.enable_manager:
|
||||
if args.enable_manager and not IS_PYISOLATE_CHILD:
|
||||
comfyui_manager.prestartup()
|
||||
|
||||
execute_prestartup_script()
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
execute_prestartup_script()
|
||||
|
||||
|
||||
# Main code
|
||||
@ -192,17 +245,17 @@ import gc
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
|
||||
import comfy.utils
|
||||
|
||||
import execution
|
||||
import server
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
import app.logger
|
||||
import hook_breaker_ac10a0
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
import execution
|
||||
import server
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
import app.logger
|
||||
import hook_breaker_ac10a0
|
||||
|
||||
import comfy.memory_management
|
||||
import comfy.model_patcher
|
||||
@ -462,6 +515,10 @@ def start_comfyui(asyncio_loop=None):
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
if args.use_process_isolation:
|
||||
from comfy.isolation import start_isolation_loading_early
|
||||
start_isolation_loading_early(asyncio_loop)
|
||||
|
||||
if args.enable_manager and not args.disable_manager_ui:
|
||||
comfyui_manager.start()
|
||||
|
||||
@ -506,12 +563,13 @@ def start_comfyui(asyncio_loop=None):
|
||||
if __name__ == "__main__":
|
||||
# Running directly, just start ComfyUI.
|
||||
logging.info("Python version: {}".format(sys.version))
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
for package in ("comfy-aimdo", "comfy-kitchen"):
|
||||
try:
|
||||
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
|
||||
except:
|
||||
pass
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
for package in ("comfy-aimdo", "comfy-kitchen"):
|
||||
try:
|
||||
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
|
||||
except:
|
||||
pass
|
||||
|
||||
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
||||
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
||||
|
||||
40
nodes.py
40
nodes.py
@ -1912,6 +1912,7 @@ class ImageInvert:
|
||||
|
||||
class ImageBatch:
|
||||
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -2295,6 +2296,27 @@ async def init_external_custom_nodes():
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
whitelist = set()
|
||||
isolated_module_paths = set()
|
||||
if args.use_process_isolation:
|
||||
from pathlib import Path
|
||||
from comfy.isolation import await_isolation_loading, get_claimed_paths
|
||||
from comfy.isolation.host_policy import load_host_policy
|
||||
|
||||
# Load Global Host Policy
|
||||
host_policy = load_host_policy(Path(folder_paths.base_path))
|
||||
whitelist_dict = host_policy.get("whitelist", {})
|
||||
# Normalize whitelist keys to lowercase for case-insensitive matching
|
||||
# (matches ComfyUI-Manager's normalization: project.name.strip().lower())
|
||||
whitelist = set(k.strip().lower() for k in whitelist_dict.keys())
|
||||
logging.info(f"][ Loaded Whitelist: {len(whitelist)} nodes allowed.")
|
||||
|
||||
isolated_specs = await await_isolation_loading()
|
||||
for spec in isolated_specs:
|
||||
NODE_CLASS_MAPPINGS.setdefault(spec.node_name, spec.stub_class)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.setdefault(spec.node_name, spec.display_name)
|
||||
isolated_module_paths = get_claimed_paths()
|
||||
|
||||
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
@ -2318,6 +2340,16 @@ async def init_external_custom_nodes():
|
||||
logging.info(f"Blocked by policy: {module_path}")
|
||||
continue
|
||||
|
||||
if args.use_process_isolation:
|
||||
if Path(module_path).resolve() in isolated_module_paths:
|
||||
continue
|
||||
|
||||
# Tri-State Enforcement: If not Isolated (checked above), MUST be Whitelisted.
|
||||
# Normalize to lowercase for case-insensitive matching (matches ComfyUI-Manager)
|
||||
if possible_module.strip().lower() not in whitelist:
|
||||
logging.warning(f"][ REJECTED: Node '{possible_module}' is blocked by security policy (not whitelisted/isolated).")
|
||||
continue
|
||||
|
||||
time_before = time.perf_counter()
|
||||
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
@ -2332,6 +2364,14 @@ async def init_external_custom_nodes():
|
||||
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
|
||||
logging.info("")
|
||||
|
||||
if args.use_process_isolation:
|
||||
from comfy.isolation import isolated_node_timings
|
||||
if isolated_node_timings:
|
||||
logging.info("\nImport times for isolated custom nodes:")
|
||||
for timing, path, count in sorted(isolated_node_timings):
|
||||
logging.info("{:6.1f} seconds: {} ({})".format(timing, path, count))
|
||||
logging.info("")
|
||||
|
||||
async def init_builtin_extra_nodes():
|
||||
"""
|
||||
Initializes the built-in extra nodes in ComfyUI.
|
||||
|
||||
44
server.py
44
server.py
@ -3,7 +3,6 @@ import sys
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
@ -202,6 +201,8 @@ def create_block_external_middleware():
|
||||
class PromptServer():
|
||||
def __init__(self, loop):
|
||||
PromptServer.instance = self
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self.user_manager = UserManager()
|
||||
self.model_file_manager = ModelFileManager()
|
||||
@ -352,6 +353,17 @@ class PromptServer():
|
||||
extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
|
||||
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
||||
|
||||
# Include JS files from proxied web directories (isolated nodes)
|
||||
if args.use_process_isolation:
|
||||
from comfy.isolation.proxies.web_directory_proxy import get_web_directory_cache
|
||||
cache = get_web_directory_cache()
|
||||
for ext_name in cache.extension_names:
|
||||
for entry in cache.list_files(ext_name):
|
||||
if entry["relative_path"].endswith(".js"):
|
||||
extensions.append(
|
||||
"/extensions/" + urllib.parse.quote(ext_name) + "/" + entry["relative_path"]
|
||||
)
|
||||
|
||||
return web.json_response(extensions)
|
||||
|
||||
def get_dir_by_type(dir_type):
|
||||
@ -1067,6 +1079,36 @@ class PromptServer():
|
||||
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||
self.app.add_routes([web.static('/extensions/' + name, dir)])
|
||||
|
||||
# Add dynamic handler for proxied web directories (isolated nodes)
|
||||
if args.use_process_isolation:
|
||||
from comfy.isolation.proxies.web_directory_proxy import (
|
||||
get_web_directory_cache,
|
||||
ALLOWED_EXTENSIONS,
|
||||
)
|
||||
|
||||
async def proxied_web_handler(request):
|
||||
ext_name = request.match_info["ext_name"]
|
||||
file_path = request.match_info["file_path"]
|
||||
|
||||
suffix = os.path.splitext(file_path)[1].lower()
|
||||
if suffix not in ALLOWED_EXTENSIONS:
|
||||
return web.Response(status=403, text="Forbidden file type")
|
||||
|
||||
cache = get_web_directory_cache()
|
||||
result = cache.get_file(ext_name, file_path)
|
||||
if result is None:
|
||||
return web.Response(status=404, text="Not found")
|
||||
|
||||
return web.Response(
|
||||
body=result["content"],
|
||||
content_type=result["content_type"],
|
||||
)
|
||||
|
||||
self.app.router.add_get(
|
||||
"/extensions/{ext_name}/{file_path:.*}",
|
||||
proxied_web_handler,
|
||||
)
|
||||
|
||||
installed_templates_version = FrontendManager.get_installed_templates_version()
|
||||
use_legacy_templates = True
|
||||
if installed_templates_version:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user