From d1155318eb885f0c454f4b5072fc832c65533cec Mon Sep 17 00:00:00 2001 From: John Pollock Date: Thu, 30 Apr 2026 14:51:08 -0500 Subject: [PATCH] feat(isolation): execution and extension loading integration --- comfy/isolation/extension_loader.py | 540 +++++++++++++++ comfy/isolation/extension_wrapper.py | 942 +++++++++++++++++++++++++++ execution.py | 161 ++++- main.py | 120 +++- nodes.py | 40 ++ server.py | 44 +- 6 files changed, 1799 insertions(+), 48 deletions(-) create mode 100644 comfy/isolation/extension_loader.py create mode 100644 comfy/isolation/extension_wrapper.py diff --git a/comfy/isolation/extension_loader.py b/comfy/isolation/extension_loader.py new file mode 100644 index 000000000..1574c58e7 --- /dev/null +++ b/comfy/isolation/extension_loader.py @@ -0,0 +1,540 @@ +# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name +from __future__ import annotations + +import logging +import os +import inspect +import sys +import types +import platform +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple + +import pyisolate +from pyisolate import ExtensionManager, ExtensionManagerConfig +from packaging.requirements import InvalidRequirement, Requirement +from packaging.utils import canonicalize_name + +from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache +from .host_policy import load_host_policy + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +logger = logging.getLogger(__name__) + + +def _register_web_directory(extension_name: str, node_dir: Path) -> None: + """Register an isolated extension's web directory on the host side.""" + import nodes + + # Method 1: pyproject.toml [tool.comfy] web field + pyproject = node_dir / "pyproject.toml" + if pyproject.exists(): + try: + with pyproject.open("rb") as f: + data = tomllib.load(f) + web_dir_name = data.get("tool", {}).get("comfy", {}).get("web") + if web_dir_name: + web_dir_path = str(node_dir / web_dir_name) + if os.path.isdir(web_dir_path): + nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path + logger.debug( + "][ Registered web dir for isolated %s: %s", + extension_name, + web_dir_path, + ) + return + except Exception: + pass + + # Method 2: __init__.py WEB_DIRECTORY constant (parse without importing) + init_file = node_dir / "__init__.py" + if init_file.exists(): + try: + source = init_file.read_text() + for line in source.splitlines(): + stripped = line.strip() + if stripped.startswith("WEB_DIRECTORY"): + # Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web" + _, _, value = stripped.partition("=") + value = value.strip().strip("\"'") + if value: + web_dir_path = str((node_dir / value).resolve()) + if os.path.isdir(web_dir_path): + nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path + logger.debug( + "][ Registered web dir for isolated %s: %s", + extension_name, + web_dir_path, + ) + return + except Exception: + pass + + +def _get_extension_type(execution_model: str) -> type[Any]: + if execution_model == "sealed_worker": + return pyisolate.SealedNodeExtension + + from .extension_wrapper import ComfyNodeExtension + + return ComfyNodeExtension + + +async def _stop_extension_safe(extension: Any, extension_name: str) -> None: + try: + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + except Exception: + logger.debug("][ %s stop failed", extension_name, exc_info=True) + + +def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str: + req, sep, marker = dep.partition(";") + req = req.strip() + marker_suffix = f";{marker}" if sep else "" + + def _resolve_local_path(local_path: str) -> Path | None: + for base in base_paths: + candidate = (base / local_path).resolve() + if candidate.exists(): + return candidate + return None + + if req.startswith("./") or req.startswith("../"): + resolved = _resolve_local_path(req) + if resolved is not None: + return f"{resolved}{marker_suffix}" + + if req.startswith("file://"): + raw = req[len("file://") :] + if raw.startswith("./") or raw.startswith("../"): + resolved = _resolve_local_path(raw) + if resolved is not None: + return f"file://{resolved}{marker_suffix}" + + return dep + + +def _dependency_name_from_spec(dep: str) -> str | None: + stripped = dep.strip() + if not stripped or stripped == "-e" or stripped.startswith("-e "): + return None + if stripped.startswith(("/", "./", "../", "file://")): + return None + + try: + return canonicalize_name(Requirement(stripped).name) + except InvalidRequirement: + return None + + +def _parse_cuda_wheels_config( + tool_config: dict[str, object], dependencies: list[str] +) -> dict[str, object] | None: + raw_config = tool_config.get("cuda_wheels") + if raw_config is None: + return None + if not isinstance(raw_config, dict): + raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table") + + index_url = raw_config.get("index_url") + index_urls = raw_config.get("index_urls") + if index_urls is not None: + if not isinstance(index_urls, list) or not all( + isinstance(u, str) and u.strip() for u in index_urls + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings" + ) + elif not isinstance(index_url, str) or not index_url.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string" + ) + + packages = raw_config.get("packages") + if not isinstance(packages, list) or not all( + isinstance(package_name, str) and package_name.strip() + for package_name in packages + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings" + ) + + declared_dependencies = { + dependency_name + for dep in dependencies + if (dependency_name := _dependency_name_from_spec(dep)) is not None + } + normalized_packages = [canonicalize_name(package_name) for package_name in packages] + missing = [ + package_name + for package_name in normalized_packages + if package_name not in declared_dependencies + ] + if missing: + missing_joined = ", ".join(sorted(missing)) + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: " + f"{missing_joined}" + ) + + package_map = raw_config.get("package_map", {}) + if not isinstance(package_map, dict): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] must be a table" + ) + + normalized_package_map: dict[str, str] = {} + for dependency_name, index_package_name in package_map.items(): + if not isinstance(dependency_name, str) or not dependency_name.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings" + ) + if not isinstance(index_package_name, str) or not index_package_name.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings" + ) + canonical_dependency_name = canonicalize_name(dependency_name) + if canonical_dependency_name not in normalized_packages: + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in " + "[tool.comfy.isolation.cuda_wheels.packages]" + ) + normalized_package_map[canonical_dependency_name] = index_package_name.strip() + + result: dict = { + "packages": normalized_packages, + "package_map": normalized_package_map, + } + if index_urls is not None: + result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls] + else: + result["index_url"] = index_url.rstrip("/") + "/" + return result + + +def get_enforcement_policy() -> Dict[str, bool]: + return { + "force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1", + "force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1", + } + + +class ExtensionLoadError(RuntimeError): + pass + + +def register_dummy_module(extension_name: str, node_dir: Path) -> None: + normalized_name = extension_name.replace("-", "_").replace(".", "_") + if normalized_name not in sys.modules: + dummy_module = types.ModuleType(normalized_name) + dummy_module.__file__ = str(node_dir / "__init__.py") + dummy_module.__path__ = [str(node_dir)] + dummy_module.__package__ = normalized_name + sys.modules[normalized_name] = dummy_module + + +def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool: + for details in cached_data.values(): + if not isinstance(details, dict): + return True + if details.get("is_v3") and "schema_v1" not in details: + return True + return False + + +async def load_isolated_node( + node_dir: Path, + manifest_path: Path, + logger: logging.Logger, + build_stub_class: Callable[[str, Dict[str, object], Any], type], + venv_root: Path, + extension_managers: List[ExtensionManager], +) -> List[Tuple[str, str, type]]: + try: + with manifest_path.open("rb") as handle: + manifest_data = tomllib.load(handle) + except Exception as e: + logger.warning(f"][ Failed to parse {manifest_path}: {e}") + return [] + + # Parse [tool.comfy.isolation] + tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {}) + can_isolate = tool_config.get("can_isolate", False) + share_torch = tool_config.get("share_torch", False) + package_manager = tool_config.get("package_manager", "uv") + is_conda = package_manager == "conda" + execution_model = tool_config.get("execution_model") + if execution_model is None: + execution_model = "sealed_worker" if is_conda else "host-coupled" + + if "sealed_host_ro_paths" in tool_config: + raise ValueError( + "Manifest field 'sealed_host_ro_paths' is not allowed. " + "Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy." + ) + + # Conda-specific manifest fields + conda_channels: list[str] = ( + tool_config.get("conda_channels", []) if is_conda else [] + ) + conda_dependencies: list[str] = ( + tool_config.get("conda_dependencies", []) if is_conda else [] + ) + conda_platforms: list[str] = ( + tool_config.get("conda_platforms", []) if is_conda else [] + ) + conda_python: str = ( + tool_config.get("conda_python", "*") if is_conda else "*" + ) + + # Parse [project] dependencies + project_config = manifest_data.get("project", {}) + dependencies = project_config.get("dependencies", []) + if not isinstance(dependencies, list): + dependencies = [] + + # Get extension name (default to folder name if not in project.name) + extension_name = project_config.get("name", node_dir.name) + + # LOGIC: Isolation Decision + policy = get_enforcement_policy() + isolated = can_isolate or policy["force_isolated"] + + if not isolated: + return [] + + import folder_paths + + base_paths = [Path(folder_paths.base_path), node_dir] + dependencies = [ + _normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep + for dep in dependencies + ] + cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies) + + manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root)) + extension_type = _get_extension_type(execution_model) + manager: ExtensionManager = pyisolate.ExtensionManager( + extension_type, manager_config + ) + extension_managers.append(manager) + + host_policy = load_host_policy(Path(folder_paths.base_path)) + + sandbox_config = {} + is_linux = platform.system() == "Linux" + + if is_conda: + share_torch = False + share_cuda_ipc = False + else: + share_cuda_ipc = share_torch and is_linux + + if is_linux and isolated: + sandbox_config = { + "network": host_policy["allow_network"], + "writable_paths": host_policy["writable_paths"], + "readonly_paths": host_policy["readonly_paths"], + } + + extension_config: dict = { + "name": extension_name, + "module_path": str(node_dir), + "isolated": True, + "dependencies": dependencies, + "share_torch": share_torch, + "share_cuda_ipc": share_cuda_ipc, + "sandbox_mode": host_policy["sandbox_mode"], + "sandbox": sandbox_config, + } + + share_torch_no_deps = tool_config.get("share_torch_no_deps", []) + if share_torch_no_deps: + if not isinstance(share_torch_no_deps, list) or not all( + isinstance(dep, str) and dep.strip() for dep in share_torch_no_deps + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.share_torch_no_deps] must be a list of non-empty strings" + ) + extension_config["share_torch_no_deps"] = share_torch_no_deps + + _is_sealed = execution_model == "sealed_worker" + _is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux + logger.info( + "][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])", + extension_name, + "x" if share_torch else " ", + "x" if _is_sealed else " ", + "x" if _is_sandboxed else " ", + ) + + if cuda_wheels is not None: + extension_config["cuda_wheels"] = cuda_wheels + + extra_index_urls = tool_config.get("extra_index_urls", []) + if extra_index_urls: + if not isinstance(extra_index_urls, list) or not all( + isinstance(u, str) and u.strip() for u in extra_index_urls + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.extra_index_urls] must be a list of non-empty strings" + ) + extension_config["extra_index_urls"] = extra_index_urls + + # Conda-specific keys + if is_conda: + extension_config["package_manager"] = "conda" + extension_config["conda_channels"] = conda_channels + extension_config["conda_dependencies"] = conda_dependencies + extension_config["conda_python"] = conda_python + find_links = tool_config.get("find_links", []) + if find_links: + extension_config["find_links"] = find_links + if conda_platforms: + extension_config["conda_platforms"] = conda_platforms + + if execution_model != "host-coupled": + extension_config["execution_model"] = execution_model + if execution_model == "sealed_worker": + policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", []) + if isinstance(policy_ro_paths, list) and policy_ro_paths: + extension_config["sealed_host_ro_paths"] = list(policy_ro_paths) + # Sealed workers keep the host RPC service inventory even when the + # child resolves no API classes locally. + + extension = manager.load_extension(extension_config) + register_dummy_module(extension_name, node_dir) + + # Register host-side event handlers via adapter + from .adapter import ComfyUIAdapter + ComfyUIAdapter.register_host_event_handlers(extension) + + # Register web directory on the host — only when sandbox is disabled. + # In sandbox mode, serving untrusted JS to the browser is not safe. + if host_policy["sandbox_mode"] == "disabled": + _register_web_directory(extension_name, node_dir) + + # Register for proxied web serving — the child's web dir may have + # content that doesn't exist on the host (e.g., pip-installed viewer + # bundles). The WebDirectoryCache will lazily fetch via RPC. + from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache + + class ChildWebDirectoryProxy: + def __init__(self, host_extension): + self._host_extension = host_extension + self._caller = None + + def _get_caller(self): + self._host_extension.proxy + rpc = self._host_extension._extension.rpc + caller = rpc.create_caller(WebDirectoryProxy, WebDirectoryProxy.get_remote_id()) + if self._caller is not caller: + self._caller = caller + return self._caller + + def list_web_files(self, extension_name: str): + from .proxies.base import run_sync_rpc_coro + return run_sync_rpc_coro(self._get_caller().list_web_files(extension_name)) + + def get_web_file(self, extension_name: str, relative_path: str): + from .proxies.base import run_sync_rpc_coro + return run_sync_rpc_coro( + self._get_caller().get_web_file(extension_name, relative_path) + ) + + cache = get_web_directory_cache() + cache.register_proxy(extension_name, ChildWebDirectoryProxy(extension)) + + # Try cache first (lazy spawn) + if is_cache_valid(node_dir, manifest_path, venv_root): + cached_data = load_from_cache(node_dir, venv_root) + if cached_data: + if _is_stale_node_cache(cached_data): + pass + else: + try: + flushed = await extension.flush_pending_routes() + logger.info("][ %s flushed %d routes", extension_name, flushed) + except Exception as exc: + logger.warning("][ %s route flush failed: %s", extension_name, exc) + specs: List[Tuple[str, str, type]] = [] + for node_name, details in cached_data.items(): + stub_cls = build_stub_class(node_name, details, extension) + specs.append( + (node_name, details.get("display_name", node_name), stub_cls) + ) + return specs + # Cache miss - spawn process and get metadata + + try: + remote_nodes: Dict[str, str] = await extension.list_nodes() + except Exception as exc: + logger.warning( + "][ %s metadata discovery failed, skipping isolated load: %s", + extension_name, + exc, + ) + await _stop_extension_safe(extension, extension_name) + return [] + + if not remote_nodes: + logger.debug("][ %s exposed no isolated nodes; skipping", extension_name) + await _stop_extension_safe(extension, extension_name) + return [] + + specs: List[Tuple[str, str, type]] = [] + cache_data: Dict[str, Dict] = {} + + for node_name, display_name in remote_nodes.items(): + try: + details = await extension.get_node_details(node_name) + except Exception as exc: + logger.warning( + "][ %s failed to load metadata for %s, skipping node: %s", + extension_name, + node_name, + exc, + ) + continue + details["display_name"] = display_name + cache_data[node_name] = details + stub_cls = build_stub_class(node_name, details, extension) + specs.append((node_name, display_name, stub_cls)) + + if not specs: + logger.warning( + "][ %s produced no usable nodes after metadata scan; skipping", + extension_name, + ) + await _stop_extension_safe(extension, extension_name) + return [] + + # Save metadata to cache for future runs + save_to_cache(node_dir, venv_root, cache_data, manifest_path) + logger.debug(f"][ {extension_name} metadata cached") + + # Re-check web directory AFTER child has populated it + if host_policy["sandbox_mode"] == "disabled": + _register_web_directory(extension_name, node_dir) + + # Flush any routes the child buffered during module import — must happen + # before router freeze and before we kill the child process. + try: + flushed = await extension.flush_pending_routes() + logger.info("][ %s flushed %d routes", extension_name, flushed) + except Exception as exc: + logger.warning("][ %s route flush failed: %s", extension_name, exc) + + # EJECT: Kill process after getting metadata (will respawn on first execution) + await _stop_extension_safe(extension, extension_name) + + return specs + + +__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"] diff --git a/comfy/isolation/extension_wrapper.py b/comfy/isolation/extension_wrapper.py new file mode 100644 index 000000000..afa4bceb2 --- /dev/null +++ b/comfy/isolation/extension_wrapper.py @@ -0,0 +1,942 @@ +# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position +from __future__ import annotations + +import asyncio +import torch + + +class AttrDict(dict): + def __getattr__(self, item): + try: + return self[item] + except KeyError as e: + raise AttributeError(item) from e + + def copy(self): + return AttrDict(super().copy()) + + +import importlib +import inspect +import json +import logging +import os +import sys +import uuid +from dataclasses import asdict +from typing import Any, Dict, List, Tuple + +from pyisolate import ExtensionBase + +from comfy_api.internal import _ComfyNodeInternal + +LOG_PREFIX = "][" +V3_DISCOVERY_TIMEOUT = 30 +_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + +logger = logging.getLogger(__name__) + + +def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None: + """Run the web asset copy step that prestartup_script.py used to do. + + If the module's web/ directory is empty and the module had a + prestartup_script.py that copied assets from pip packages, this + function replicates that work inside the child process. + + Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if + defined, otherwise falls back to detecting common asset packages. + """ + import shutil + + # Already populated — nothing to do + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + return + + os.makedirs(web_dir_path, exist_ok=True) + + # Try module-defined copy spec first (generic hook for any node pack) + copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None) + if copy_spec is not None and callable(copy_spec): + try: + copy_spec(web_dir_path) + logger.info( + "%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir + ) + return + except Exception as e: + logger.warning( + "%s _PRESTARTUP_WEB_COPY failed for %s: %s", + LOG_PREFIX, module_dir, e, + ) + + # Fallback: detect comfy_3d_viewers and run copy_viewer() + try: + from comfy_3d_viewers import copy_viewer, VIEWER_FILES + viewers = list(VIEWER_FILES.keys()) + for viewer in viewers: + try: + copy_viewer(viewer, web_dir_path) + except Exception: + pass + if any(os.scandir(web_dir_path)): + logger.info( + "%s Copied %d viewer types from comfy_3d_viewers to %s", + LOG_PREFIX, len(viewers), web_dir_path, + ) + except ImportError: + pass + + # Fallback: detect comfy_dynamic_widgets + try: + from comfy_dynamic_widgets import get_js_path + src = os.path.realpath(get_js_path()) + if os.path.exists(src): + dst_dir = os.path.join(web_dir_path, "js") + os.makedirs(dst_dir, exist_ok=True) + dst = os.path.join(dst_dir, "dynamic_widgets.js") + shutil.copy2(src, dst) + except ImportError: + pass + + +def _read_extension_name(module_dir: str) -> str: + """Read extension name from pyproject.toml, falling back to directory name.""" + pyproject = os.path.join(module_dir, "pyproject.toml") + if os.path.exists(pyproject): + try: + import tomllib + except ImportError: + import tomli as tomllib # type: ignore[no-redef] + try: + with open(pyproject, "rb") as f: + data = tomllib.load(f) + name = data.get("project", {}).get("name") + if name: + return name + except Exception: + pass + return os.path.basename(module_dir) + + +def _flush_tensor_transport_state(marker: str) -> int: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + return 0 + if not callable(flush_tensor_keeper): + return 0 + flushed = flush_tensor_keeper() + if flushed > 0: + logger.debug( + "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed + ) + return flushed + + +def _relieve_child_vram_pressure(marker: str) -> None: + import comfy.model_management as model_management + + model_management.cleanup_models_gc() + model_management.cleanup_models() + + device = model_management.get_torch_device() + if not hasattr(device, "type") or device.type == "cpu": + return + + required = max( + model_management.minimum_inference_memory(), + _PRE_EXEC_MIN_FREE_VRAM_BYTES, + ) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=True) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.cleanup_models() + model_management.soft_empty_cache() + logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required) + + +def _sanitize_for_transport(value): + primitives = (str, int, float, bool, type(None)) + if isinstance(value, primitives): + return value + + cls_name = value.__class__.__name__ + if cls_name == "FlexibleOptionalInputType": + return { + "__pyisolate_flexible_optional__": True, + "type": _sanitize_for_transport(getattr(value, "type", "*")), + } + if cls_name == "AnyType": + return {"__pyisolate_any_type__": True, "value": str(value)} + if cls_name == "ByPassTypeTuple": + return { + "__pyisolate_bypass_tuple__": [ + _sanitize_for_transport(v) for v in tuple(value) + ] + } + + if isinstance(value, dict): + return {k: _sanitize_for_transport(v) for k, v in value.items()} + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]} + if isinstance(value, list): + return [_sanitize_for_transport(v) for v in value] + + return str(value) + + +# Re-export RemoteObjectHandle from pyisolate for backward compatibility +# The canonical definition is now in pyisolate._internal.remote_handle +from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401 + + +class ComfyNodeExtension(ExtensionBase): + def __init__(self) -> None: + super().__init__() + self.node_classes: Dict[str, type] = {} + self.display_names: Dict[str, str] = {} + self.node_instances: Dict[str, Any] = {} + self.remote_objects: Dict[str, Any] = {} + self._route_handlers: Dict[str, Any] = {} + self._module: Any = None + self._metadata_ready = asyncio.Event() + + async def on_module_loaded(self, module: Any) -> None: + try: + self._module = module + + # Registries are initialized in host_hooks.py initialize_host_process() + # They auto-register via ProxiedSingleton when instantiated + # NO additional setup required here - if a registry is missing from host_hooks, it WILL fail + + self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {} + self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {} + self._register_module_routes(module) + + # Register web directory with WebDirectoryProxy (child-side) + web_dir_attr = getattr(module, "WEB_DIRECTORY", None) + if web_dir_attr is not None: + module_dir = os.path.dirname(os.path.abspath(module.__file__)) + web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr)) + ext_name = _read_extension_name(module_dir) + + # If web dir is empty, run the copy step that prestartup_script.py did + _run_prestartup_web_copy(module, module_dir, web_dir_path) + + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + WebDirectoryProxy.register_web_dir(ext_name, web_dir_path) + + try: + from comfy_api.latest import ComfyExtension + + for name, obj in inspect.getmembers(module): + if not ( + inspect.isclass(obj) + and issubclass(obj, ComfyExtension) + and obj is not ComfyExtension + ): + continue + if not obj.__module__.startswith(module.__name__): + continue + try: + ext_instance = obj() + try: + await asyncio.wait_for( + ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT + ) + except asyncio.TimeoutError: + logger.error( + "%s V3 Extension %s timed out in on_load()", + LOG_PREFIX, + name, + ) + continue + try: + v3_nodes = await asyncio.wait_for( + ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT + ) + except asyncio.TimeoutError: + logger.error( + "%s V3 Extension %s timed out in get_node_list()", + LOG_PREFIX, + name, + ) + continue + for node_cls in v3_nodes: + if hasattr(node_cls, "GET_SCHEMA"): + schema = node_cls.GET_SCHEMA() + self.node_classes[schema.node_id] = node_cls + if schema.display_name: + self.display_names[schema.node_id] = schema.display_name + except Exception as e: + logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e) + except ImportError: + pass + + module_name = getattr(module, "__name__", "isolated_nodes") + for node_cls in self.node_classes.values(): + if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__): + node_cls.__module__ = module_name + + self.node_instances = {} + finally: + self._metadata_ready.set() + + def _register_module_routes(self, module: Any) -> None: + """Bridge legacy module-level ROUTES declarations into isolated routing.""" + routes = getattr(module, "ROUTES", None) or [] + if not routes: + return + + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + + prompt_server = PromptServerStub() + route_table = getattr(prompt_server, "routes", None) + if route_table is None: + logger.warning("%s Route registration unavailable for %s", LOG_PREFIX, module) + return + + for route_spec in routes: + if not isinstance(route_spec, dict): + logger.warning("%s Ignoring non-dict ROUTES entry: %r", LOG_PREFIX, route_spec) + continue + + method = str(route_spec.get("method", "")).strip().upper() + path = str(route_spec.get("path", "")).strip() + handler_ref = route_spec.get("handler") + if not method or not path: + logger.warning("%s Ignoring incomplete route spec: %r", LOG_PREFIX, route_spec) + continue + + if isinstance(handler_ref, str): + handler = getattr(module, handler_ref, None) + else: + handler = handler_ref + if not callable(handler): + logger.warning( + "%s Ignoring route with missing handler %r for %s %s", + LOG_PREFIX, + handler_ref, + method, + path, + ) + continue + + decorator = getattr(route_table, method.lower(), None) + if not callable(decorator): + logger.warning("%s Unsupported route method %s for %s", LOG_PREFIX, method, path) + continue + + decorator(path)(handler) + self._route_handlers[f"{method} {path}"] = handler + logger.info("%s buffered legacy route %s %s", LOG_PREFIX, method, path) + + async def list_nodes(self) -> Dict[str, str]: + await asyncio.wait_for( + self._metadata_ready.wait(), timeout=V3_DISCOVERY_TIMEOUT + ) + return {name: self.display_names.get(name, name) for name in self.node_classes} + + async def get_node_info(self, node_name: str) -> Dict[str, Any]: + return await self.get_node_details(node_name) + + async def get_node_details(self, node_name: str) -> Dict[str, Any]: + await asyncio.wait_for( + self._metadata_ready.wait(), timeout=V3_DISCOVERY_TIMEOUT + ) + node_cls = self._get_node_class(node_name) + is_v3 = issubclass(node_cls, _ComfyNodeInternal) + + input_types_raw = ( + node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {} + ) + output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None) + if output_is_list is not None: + output_is_list = tuple(bool(x) for x in output_is_list) + + details: Dict[str, Any] = { + "input_types": _sanitize_for_transport(input_types_raw), + "return_types": tuple( + str(t) for t in getattr(node_cls, "RETURN_TYPES", ()) + ), + "return_names": getattr(node_cls, "RETURN_NAMES", None), + "function": str(getattr(node_cls, "FUNCTION", "execute")), + "category": str(getattr(node_cls, "CATEGORY", "")), + "output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)), + "output_is_list": output_is_list, + "is_v3": is_v3, + } + + if is_v3: + try: + schema = node_cls.GET_SCHEMA() + schema_v1 = asdict(schema.get_v1_info(node_cls)) + try: + schema_v3 = asdict(schema.get_v3_info(node_cls)) + except (AttributeError, TypeError): + schema_v3 = self._build_schema_v3_fallback(schema) + details.update( + { + "schema_v1": schema_v1, + "schema_v3": schema_v3, + "hidden": [h.value for h in (schema.hidden or [])], + "description": getattr(schema, "description", ""), + "deprecated": bool(getattr(node_cls, "DEPRECATED", False)), + "experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)), + "api_node": bool(getattr(node_cls, "API_NODE", False)), + "input_is_list": bool( + getattr(node_cls, "INPUT_IS_LIST", False) + ), + "not_idempotent": bool( + getattr(node_cls, "NOT_IDEMPOTENT", False) + ), + "accept_all_inputs": bool( + getattr(node_cls, "ACCEPT_ALL_INPUTS", False) + ), + } + ) + except Exception as exc: + logger.warning( + "%s V3 schema serialization failed for %s: %s", + LOG_PREFIX, + node_name, + exc, + ) + return details + + def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]: + input_dict: Dict[str, Any] = {} + output_dict: Dict[str, Any] = {} + hidden_list: List[str] = [] + + if getattr(schema, "inputs", None): + for inp in schema.inputs: + self._add_schema_io_v3(inp, input_dict) + if getattr(schema, "outputs", None): + for out in schema.outputs: + self._add_schema_io_v3(out, output_dict) + if getattr(schema, "hidden", None): + for h in schema.hidden: + hidden_list.append(getattr(h, "value", str(h))) + + return { + "input": input_dict, + "output": output_dict, + "hidden": hidden_list, + "name": getattr(schema, "node_id", None), + "display_name": getattr(schema, "display_name", None), + "description": getattr(schema, "description", None), + "category": getattr(schema, "category", None), + "output_node": getattr(schema, "is_output_node", False), + "deprecated": getattr(schema, "is_deprecated", False), + "experimental": getattr(schema, "is_experimental", False), + "api_node": getattr(schema, "is_api_node", False), + } + + def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None: + io_id = getattr(io_obj, "id", None) + if io_id is None: + return + + io_type_fn = getattr(io_obj, "get_io_type", None) + io_type = ( + io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None) + ) + + as_dict_fn = getattr(io_obj, "as_dict", None) + payload = as_dict_fn() if callable(as_dict_fn) else {} + + target[str(io_id)] = (io_type, payload) + + async def get_input_types(self, node_name: str) -> Dict[str, Any]: + node_cls = self._get_node_class(node_name) + if hasattr(node_cls, "INPUT_TYPES"): + return node_cls.INPUT_TYPES() + return {} + + async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]: + logger.debug( + "%s ISO:child_execute_start ext=%s node=%s input_keys=%d", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + len(inputs), + ) + if os.environ.get("PYISOLATE_CHILD") == "1": + _relieve_child_vram_pressure("EXT:pre_execute") + + resolved_inputs = self._resolve_remote_objects(inputs) + + instance = self._get_node_instance(node_name) + node_cls = self._get_node_class(node_name) + + # V3 API nodes expect hidden parameters in cls.hidden, not as kwargs + # Hidden params come through RPC as string keys like "Hidden.prompt" + from comfy_api.latest._io import Hidden, HiddenHolder + + # Map string representations back to Hidden enum keys + hidden_string_map = { + "Hidden.unique_id": Hidden.unique_id, + "Hidden.prompt": Hidden.prompt, + "Hidden.extra_pnginfo": Hidden.extra_pnginfo, + "Hidden.dynprompt": Hidden.dynprompt, + "Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org, + "Hidden.api_key_comfy_org": Hidden.api_key_comfy_org, + # Uppercase enum VALUE forms — V3 execution engine passes these + "UNIQUE_ID": Hidden.unique_id, + "PROMPT": Hidden.prompt, + "EXTRA_PNGINFO": Hidden.extra_pnginfo, + "DYNPROMPT": Hidden.dynprompt, + "AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org, + "API_KEY_COMFY_ORG": Hidden.api_key_comfy_org, + } + + # Find and extract hidden parameters (both enum and string form) + hidden_found = {} + keys_to_remove = [] + + for key in list(resolved_inputs.keys()): + # Check string form first (from RPC serialization) + if key in hidden_string_map: + hidden_found[hidden_string_map[key]] = resolved_inputs[key] + keys_to_remove.append(key) + # Also check enum form (direct calls) + elif isinstance(key, Hidden): + hidden_found[key] = resolved_inputs[key] + keys_to_remove.append(key) + + # Remove hidden params from kwargs + for key in keys_to_remove: + resolved_inputs.pop(key) + + # Set hidden on node class if any hidden params found + if hidden_found: + if not hasattr(node_cls, "hidden") or node_cls.hidden is None: + node_cls.hidden = HiddenHolder.from_dict(hidden_found) + else: + # Update existing hidden holder + for key, value in hidden_found.items(): + setattr(node_cls.hidden, key.value.lower(), value) + + # INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this + # flag is set. The isolation RPC delivers unwrapped values, so we must + # wrap each input in a single-element list to match the contract. + if getattr(node_cls, "INPUT_IS_LIST", False): + resolved_inputs = {k: [v] for k, v in resolved_inputs.items()} + + function_name = getattr(node_cls, "FUNCTION", "execute") + if not hasattr(instance, function_name): + raise AttributeError(f"Node {node_name} missing callable '{function_name}'") + + handler = getattr(instance, function_name) + + try: + import torch + if asyncio.iscoroutinefunction(handler): + with torch.inference_mode(): + result = await handler(**resolved_inputs) + else: + import functools + + def _run_with_inference_mode(**kwargs): + with torch.inference_mode(): + return handler(**kwargs) + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor( + None, functools.partial(_run_with_inference_mode, **resolved_inputs) + ) + except Exception: + logger.exception( + "%s ISO:child_execute_error ext=%s node=%s", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + ) + raise + + if type(result).__name__ == "NodeOutput": + node_output_dict = { + "__node_output__": True, + "args": self._wrap_unpicklable_objects(result.args), + } + if result.ui is not None: + node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui) + if getattr(result, "expand", None) is not None: + node_output_dict["expand"] = result.expand + if getattr(result, "block_execution", None) is not None: + node_output_dict["block_execution"] = result.block_execution + return node_output_dict + if self._is_comfy_protocol_return(result): + wrapped = self._wrap_unpicklable_objects(result) + return wrapped + + if not isinstance(result, tuple): + result = (result,) + wrapped = self._wrap_unpicklable_objects(result) + return wrapped + + async def flush_pending_routes(self) -> int: + """Flush buffered route registrations to host via RPC. Called by host after node discovery.""" + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + return await PromptServerStub.flush_child_routes() + + async def flush_transport_state(self) -> int: + if os.environ.get("PYISOLATE_CHILD") != "1": + return 0 + logger.debug( + "%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?") + ) + flushed = _flush_tensor_transport_state("EXT:workflow_end") + try: + from comfy.isolation.model_patcher_proxy_registry import ( + ModelPatcherRegistry, + ) + + registry = ModelPatcherRegistry() + removed = registry.sweep_pending_cleanup() + if removed > 0: + logger.debug( + "%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed + ) + except Exception: + logger.debug( + "%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True + ) + logger.debug( + "%s ISO:child_flush_done ext=%s flushed=%d", + LOG_PREFIX, + getattr(self, "name", "?"), + flushed, + ) + return flushed + + async def get_remote_object(self, object_id: str) -> Any: + """Retrieve a remote object by ID for host-side deserialization.""" + if object_id not in self.remote_objects: + raise KeyError(f"Remote object {object_id} not found") + + return self.remote_objects[object_id] + + def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle: + object_id = str(uuid.uuid4()) + self.remote_objects[object_id] = obj + return RemoteObjectHandle(object_id, type(obj).__name__) + + async def call_remote_object_method( + self, + object_id: str, + method_name: str, + *args: Any, + **kwargs: Any, + ) -> Any: + """Invoke a method or attribute-backed accessor on a child-owned object.""" + obj = await self.get_remote_object(object_id) + + if method_name == "get_patcher_attr": + return getattr(obj, args[0]) + if method_name == "get_model_options": + return getattr(obj, "model_options") + if method_name == "set_model_options": + setattr(obj, "model_options", args[0]) + return None + if method_name == "get_object_patches": + return getattr(obj, "object_patches") + if method_name == "get_patches": + return getattr(obj, "patches") + if method_name == "get_wrappers": + return getattr(obj, "wrappers") + if method_name == "get_callbacks": + return getattr(obj, "callbacks") + if method_name == "get_load_device": + return getattr(obj, "load_device") + if method_name == "get_offload_device": + return getattr(obj, "offload_device") + if method_name == "get_hook_mode": + return getattr(obj, "hook_mode") + if method_name == "get_parent": + parent = getattr(obj, "parent", None) + if parent is None: + return None + return self._store_remote_object_handle(parent) + if method_name == "get_inner_model_attr": + attr_name = args[0] + if hasattr(obj.model, attr_name): + return getattr(obj.model, attr_name) + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + return None + if method_name == "inner_model_apply_model": + return obj.model.apply_model(*args[0], **args[1]) + if method_name == "inner_model_extra_conds_shapes": + return obj.model.extra_conds_shapes(*args[0], **args[1]) + if method_name == "inner_model_extra_conds": + return obj.model.extra_conds(*args[0], **args[1]) + if method_name == "inner_model_memory_required": + return obj.model.memory_required(*args[0], **args[1]) + if method_name == "process_latent_in": + return obj.model.process_latent_in(*args[0], **args[1]) + if method_name == "process_latent_out": + return obj.model.process_latent_out(*args[0], **args[1]) + if method_name == "scale_latent_inpaint": + return obj.model.scale_latent_inpaint(*args[0], **args[1]) + if method_name.startswith("get_"): + attr_name = method_name[4:] + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + + target = getattr(obj, method_name) + if callable(target): + result = target(*args, **kwargs) + if inspect.isawaitable(result): + result = await result + if type(result).__name__ == "ModelPatcher": + return self._store_remote_object_handle(result) + return result + if args or kwargs: + raise TypeError(f"{method_name} is not callable on remote object {object_id}") + return target + + def _wrap_unpicklable_objects(self, data: Any) -> Any: + if isinstance(data, (str, int, float, bool, type(None))): + return data + if isinstance(data, torch.Tensor): + tensor = data.detach() if data.requires_grad else data + if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu": + return tensor.cpu() + return tensor + + # Special-case clip vision outputs: preserve attribute access by packing fields + if hasattr(data, "penultimate_hidden_states") or hasattr( + data, "last_hidden_state" + ): + fields = {} + for attr in ( + "penultimate_hidden_states", + "last_hidden_state", + "image_embeds", + "text_embeds", + ): + if hasattr(data, attr): + try: + fields[attr] = self._wrap_unpicklable_objects( + getattr(data, attr) + ) + except Exception: + pass + if fields: + return {"__pyisolate_attribute_container__": True, "data": fields} + + # Avoid converting arbitrary objects with stateful methods (models, etc.) + # They will be handled via RemoteObjectHandle below. + + type_name = type(data).__name__ + if type_name == "ModelPatcherProxy": + return {"__type__": "ModelPatcherRef", "model_id": data._instance_id} + if type_name == "CLIPProxy": + return {"__type__": "CLIPRef", "clip_id": data._instance_id} + if type_name == "VAEProxy": + return {"__type__": "VAERef", "vae_id": data._instance_id} + if type_name == "ModelSamplingProxy": + return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id} + + if isinstance(data, (list, tuple)): + wrapped = [self._wrap_unpicklable_objects(item) for item in data] + return tuple(wrapped) if isinstance(data, tuple) else wrapped + if isinstance(data, dict): + converted_dict = { + k: self._wrap_unpicklable_objects(v) for k, v in data.items() + } + return {"__pyisolate_attrdict__": True, "data": converted_dict} + + from pyisolate._internal.serialization_registry import SerializerRegistry + + registry = SerializerRegistry.get_instance() + if registry.is_data_type(type_name): + serializer = registry.get_serializer(type_name) + if serializer: + return serializer(data) + + return self._store_remote_object_handle(data) + + def _resolve_remote_objects(self, data: Any) -> Any: + if isinstance(data, RemoteObjectHandle): + if data.object_id not in self.remote_objects: + raise KeyError(f"Remote object {data.object_id} not found") + return self.remote_objects[data.object_id] + + if isinstance(data, dict): + ref_type = data.get("__type__") + if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"): + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + if ref_type == "ModelSamplingRef": + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + return {k: self._resolve_remote_objects(v) for k, v in data.items()} + + if isinstance(data, (list, tuple)): + resolved = [self._resolve_remote_objects(item) for item in data] + return tuple(resolved) if isinstance(data, tuple) else resolved + return data + + def _get_node_class(self, node_name: str) -> type: + if node_name not in self.node_classes: + raise KeyError(f"Unknown node: {node_name}") + return self.node_classes[node_name] + + def _get_node_instance(self, node_name: str) -> Any: + if node_name not in self.node_instances: + if node_name not in self.node_classes: + raise KeyError(f"Unknown node: {node_name}") + self.node_instances[node_name] = self.node_classes[node_name]() + return self.node_instances[node_name] + + async def before_module_loaded(self) -> None: + try: + from comfy.isolation import initialize_proxies + + initialize_proxies() + except Exception as e: + logger.error( + "%s before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e + ) + + await super().before_module_loaded() + try: + from comfy_api.latest import ComfyAPI_latest + from .proxies.progress_proxy import ProgressProxy + + ComfyAPI_latest.Execution = ProgressProxy + # ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision + # fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision + # latest_ui.folder_paths = fp_proxy + # latest_resources.folder_paths = fp_proxy + except Exception: + pass + + async def call_route_handler( + self, + handler_module: str, + handler_func: str, + request_data: Dict[str, Any], + ) -> Any: + cache_key = f"{handler_module}.{handler_func}" + if cache_key not in self._route_handlers: + if self._module is not None and hasattr(self._module, "__file__"): + node_dir = os.path.dirname(self._module.__file__) + if node_dir not in sys.path: + sys.path.insert(0, node_dir) + try: + module = importlib.import_module(handler_module) + self._route_handlers[cache_key] = getattr(module, handler_func) + except (ImportError, AttributeError) as e: + raise ValueError(f"Route handler not found: {cache_key}") from e + + handler = self._route_handlers[cache_key] + mock_request = MockRequest(request_data) + + if asyncio.iscoroutinefunction(handler): + result = await handler(mock_request) + else: + result = handler(mock_request) + return self._serialize_response(result) + + def _is_comfy_protocol_return(self, result: Any) -> bool: + """ + Check if the result matches the ComfyUI 'Protocol Return' schema. + + A Protocol Return is a dictionary containing specific reserved keys that + ComfyUI's execution engine interprets as instructions (UI updates, + Workflow expansion, etc.) rather than purely data outputs. + + Schema: + - Must be a dict + - Must contain at least one of: 'ui', 'result', 'expand' + """ + if not isinstance(result, dict): + return False + return any(key in result for key in ("ui", "result", "expand")) + + def _serialize_response(self, response: Any) -> Dict[str, Any]: + if response is None: + return {"type": "text", "body": "", "status": 204} + if isinstance(response, dict): + return {"type": "json", "body": response, "status": 200} + if isinstance(response, str): + return {"type": "text", "body": response, "status": 200} + if hasattr(response, "text") and hasattr(response, "status"): + return { + "type": "text", + "body": response.text + if hasattr(response, "text") + else str(response.body), + "status": response.status, + "headers": dict(response.headers) + if hasattr(response, "headers") + else {}, + } + if hasattr(response, "body") and hasattr(response, "status"): + body = response.body + if isinstance(body, bytes): + try: + return { + "type": "text", + "body": body.decode("utf-8"), + "status": response.status, + } + except UnicodeDecodeError: + return { + "type": "binary", + "body": body.hex(), + "status": response.status, + } + return {"type": "json", "body": body, "status": response.status} + return {"type": "text", "body": str(response), "status": 200} + + +class MockRequest: + def __init__(self, data: Dict[str, Any]): + self.method = data.get("method", "GET") + self.path = data.get("path", "/") + self.query = data.get("query", {}) + self._body = data.get("body", {}) + self._text = data.get("text", "") + self.headers = data.get("headers", {}) + self.content_type = data.get( + "content_type", self.headers.get("Content-Type", "application/json") + ) + self.match_info = data.get("match_info", {}) + + async def json(self) -> Any: + if isinstance(self._body, dict): + return self._body + if isinstance(self._body, str): + return json.loads(self._body) + return {} + + async def post(self) -> Dict[str, Any]: + if isinstance(self._body, dict): + return self._body + return {} + + async def text(self) -> str: + if self._text: + return self._text + if isinstance(self._body, str): + return self._body + if isinstance(self._body, dict): + return json.dumps(self._body) + return "" + + async def read(self) -> bytes: + return (await self.text()).encode("utf-8") diff --git a/execution.py b/execution.py index 5a6d3404c..ee4b34684 100644 --- a/execution.py +++ b/execution.py @@ -1,7 +1,9 @@ import copy +import gc import heapq import inspect import logging +import os import sys import threading import time @@ -42,6 +44,8 @@ from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_re from comfy_api.latest import io, _io from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger +_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = False + class ExecutionResult(Enum): SUCCESS = 0 @@ -262,20 +266,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f pre_execute_cb(index) # V3 if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): - # if is just a class, then assign no state, just create clone - if is_class(obj): - type_obj = obj - obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(v3_data) - # otherwise, use class instance to populate/reuse some fields + # Check for isolated node - skip validation and class cloning + if hasattr(obj, "_pyisolate_extension"): + # Isolated Node: The stub is just a proxy; real validation happens in child process + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) + # Inject hidden inputs so they're available in the isolated child process + inputs.update(v3_data.get("hidden_inputs", {})) + f = getattr(obj, func) + # Standard V3 Node (Existing Logic) + else: - type_obj = type(obj) - type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) - f = make_locked_method_func(type_obj, func, class_clone) - # in case of dynamic inputs, restructure inputs to expected nested dict - if v3_data is not None: - inputs = _io.build_nested_inputs(inputs, v3_data) + # if is just a class, then assign no resources or state, just create clone + if is_class(obj): + type_obj = obj + obj.VALIDATE_CLASS() + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) + # otherwise, use class instance to populate/reuse some fields + else: + type_obj = type(obj) + type_obj.VALIDATE_CLASS() + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) + f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) # V1 else: f = getattr(obj, func) @@ -537,7 +552,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() - comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + vbar_lib = getattr(comfy_aimdo.model_vbar, "lib", None) + if vbar_lib is not None: + comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + else: + global _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED + if not _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED: + logging.warning( + "DynamicVRAM backend unavailable for watermark reset; " + "skipping vbar reset for this process." + ) + _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = True if has_pending_tasks: pending_async_nodes[unique_id] = output_data @@ -546,8 +571,29 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, tasks = [x for x in output_data if isinstance(x, asyncio.Task)] await asyncio.gather(*tasks, return_exceptions=True) unblock() - asyncio.create_task(await_completion()) - return (ExecutionResult.PENDING, None, None) + + # Keep isolation node execution deterministic by default, but allow + # opt-out for diagnostics. + isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes") + if args.use_process_isolation and isolation_sequential: + await await_completion() + results = [] + for r in pending_async_nodes[unique_id]: + if isinstance(r, asyncio.Task): + try: + results.append(r.result()) + except Exception as ex: + del pending_async_nodes[unique_id] + raise ex + else: + results.append(r) + del pending_async_nodes[unique_id] + output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def) + has_pending_tasks = False + + else: + asyncio.create_task(await_completion()) + return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: ui_outputs[unique_id] = { "meta": { @@ -657,6 +703,46 @@ class PromptExecutor: self.status_messages = [] self.success = True + async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import notify_execution_graph + await notify_execution_graph(class_types, caches=self.caches.all) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:notify_execution_graph failed", exc_info=True) + + async def _flush_running_extensions_transport_state_safe(self) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import flush_running_extensions_transport_state + await flush_running_extensions_transport_state() + except Exception: + logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True) + + async def _wait_model_patcher_quiescence_safe( + self, + *, + fail_loud: bool = False, + timeout_ms: int = 120000, + marker: str = "EX:wait_model_patcher_idle", + ) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import wait_for_model_patcher_quiescence + + await wait_for_model_patcher_quiescence( + timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker + ) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True) + def add_message(self, event, data: dict, broadcast: bool): data = { **data, @@ -711,6 +797,18 @@ class PromptExecutor: asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + if args.use_process_isolation: + # Update RPC event loops for all isolated extensions. + # This is critical for serial workflow execution - each asyncio.run() creates + # a new event loop, and RPC instances must be updated to use it. + try: + from comfy.isolation import update_rpc_event_loops + update_rpc_event_loops() + except ImportError: + pass # Isolation not available + except Exception as e: + logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}") + set_preview_method(extra_data.get("preview_method")) nodes.interrupt_processing(False) @@ -723,6 +821,25 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) + if args.use_process_isolation: + try: + # Boundary cleanup runs at the start of the next workflow in + # isolation mode, matching non-isolated "next prompt" timing. + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) + await self._wait_model_patcher_quiescence_safe( + fail_loud=False, + timeout_ms=120000, + marker="EX:boundary_cleanup_wait_idle", + ) + await self._flush_running_extensions_transport_state_safe() + comfy.model_management.unload_all_models() + comfy.model_management.cleanup_models_gc() + comfy.model_management.cleanup_models() + gc.collect() + comfy.model_management.soft_empty_cache() + except Exception: + logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True) + self._notify_prompt_lifecycle("start", prompt_id) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None @@ -760,6 +877,18 @@ class PromptExecutor: for node_id in list(execute_outputs): execution_list.add_node(node_id) + if args.use_process_isolation: + pending_class_types = set() + for node_id in execution_list.pendingNodes.keys(): + class_type = dynamic_prompt.get_node(node_id)["class_type"] + pending_class_types.add(class_type) + await self._wait_model_patcher_quiescence_safe( + fail_loud=True, + timeout_ms=120000, + marker="EX:notify_graph_wait_idle", + ) + await self._notify_execution_graph_safe(pending_class_types, fail_loud=True) + while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() if error is not None: diff --git a/main.py b/main.py index dbaf2745c..3a45789a9 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,27 @@ +import os +import sys + +IS_PYISOLATE_CHILD = os.environ.get("PYISOLATE_CHILD") == "1" + +if __name__ == "__main__" and IS_PYISOLATE_CHILD: + del os.environ["PYISOLATE_CHILD"] + IS_PYISOLATE_CHILD = False + +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +if CURRENT_DIR not in sys.path: + sys.path.insert(0, CURRENT_DIR) + +if not IS_PYISOLATE_CHILD: + python_scripts_dir = os.path.dirname(os.path.realpath(sys.executable)) + path_entries = os.environ.get("PATH", "").split(os.pathsep) + if python_scripts_dir and python_scripts_dir not in path_entries: + os.environ["PATH"] = os.pathsep.join([python_scripts_dir, *path_entries]) + +IS_PRIMARY_PROCESS = (not IS_PYISOLATE_CHILD) and __name__ == "__main__" + import comfy.options comfy.options.enable_args_parsing() -import os import importlib.util import shutil import importlib.metadata @@ -9,26 +29,56 @@ import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger -setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) - from app.assets.seeder import asset_seeder from app.assets.services import register_output_files import itertools -import utils.extra_config +import utils.extra_config # noqa: F401 from utils.mime_types import init_mime_types import faulthandler import logging import sys -from comfy_execution.progress import get_progress_state -from comfy_execution.utils import get_executing_context -from comfy_api import feature_flags from app.database.db import init_db, dependencies_available -if __name__ == "__main__": - #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. +import comfy_aimdo.control + +if enables_dynamic_vram(): + if not comfy_aimdo.control.init(): + logging.warning( + "DynamicVRAM requested, but comfy-aimdo failed to initialize early. " + "Will fall back to legacy model loading if device init fails." + ) + +if args.use_process_isolation: + from comfy.isolation import initialize_proxies + initialize_proxies() + + # Explicitly register the ComfyUI adapter for pyisolate (v1.0 architecture) + try: + import pyisolate + from comfy.isolation.adapter import ComfyUIAdapter + pyisolate.register_adapter(ComfyUIAdapter()) + logging.info("PyIsolate adapter registered: comfyui") + except ImportError: + logging.warning("PyIsolate not installed or version too old for explicit registration") + except Exception as e: + logging.error(f"Failed to register PyIsolate adapter: {e}") + + if not IS_PYISOLATE_CHILD: + if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:native' + +if not IS_PYISOLATE_CHILD: + from comfy_execution.progress import get_progress_state + from comfy_execution.utils import get_executing_context + from comfy_api import feature_flags + +if IS_PRIMARY_PROCESS: os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' +if not IS_PYISOLATE_CHILD: + setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) + faulthandler.enable(file=sys.stderr, all_threads=False) import comfy_aimdo.control @@ -93,14 +143,15 @@ if args.enable_manager: def apply_custom_paths(): + from utils import extra_config # Deferred import - spawn re-runs main.py # extra model paths extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): - utils.extra_config.load_extra_path_config(extra_model_paths_config_path) + extra_config.load_extra_path_config(extra_model_paths_config_path) if args.extra_model_paths_config: for config_path in itertools.chain(*args.extra_model_paths_config): - utils.extra_config.load_extra_path_config(config_path) + extra_config.load_extra_path_config(config_path) # --output-directory, --input-directory, --user-directory if args.output_directory: @@ -173,15 +224,17 @@ def execute_prestartup_script(): else: import_message = " (PRESTARTUP FAILED)" logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) - logging.info("") + logging.info("") -apply_custom_paths() -init_mime_types() +if not IS_PYISOLATE_CHILD: + apply_custom_paths() + init_mime_types() -if args.enable_manager: +if args.enable_manager and not IS_PYISOLATE_CHILD: comfyui_manager.prestartup() -execute_prestartup_script() +if not IS_PYISOLATE_CHILD: + execute_prestartup_script() # Main code @@ -192,17 +245,17 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") - import comfy.utils -import execution -import server -from protocol import BinaryEventTypes -import nodes -import comfy.model_management -import comfyui_version -import app.logger -import hook_breaker_ac10a0 +if not IS_PYISOLATE_CHILD: + import execution + import server + from protocol import BinaryEventTypes + import nodes + import comfy.model_management + import comfyui_version + import app.logger + import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher @@ -462,6 +515,10 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + if args.use_process_isolation: + from comfy.isolation import start_isolation_loading_early + start_isolation_loading_early(asyncio_loop) + if args.enable_manager and not args.disable_manager_ui: comfyui_manager.start() @@ -506,12 +563,13 @@ def start_comfyui(asyncio_loop=None): if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("Python version: {}".format(sys.version)) - logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) - for package in ("comfy-aimdo", "comfy-kitchen"): - try: - logging.info("{} version: {}".format(package, importlib.metadata.version(package))) - except: - pass + if not IS_PYISOLATE_CHILD: + logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + for package in ("comfy-aimdo", "comfy-kitchen"): + try: + logging.info("{} version: {}".format(package, importlib.metadata.version(package))) + except: + pass if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") diff --git a/nodes.py b/nodes.py index 99dc07227..470e86e9a 100644 --- a/nodes.py +++ b/nodes.py @@ -1912,6 +1912,7 @@ class ImageInvert: class ImageBatch: SEARCH_ALIASES = ["combine images", "merge images", "stack images"] + ESSENTIALS_CATEGORY = "Image Tools" @classmethod def INPUT_TYPES(s): @@ -2295,6 +2296,27 @@ async def init_external_custom_nodes(): Returns: None """ + whitelist = set() + isolated_module_paths = set() + if args.use_process_isolation: + from pathlib import Path + from comfy.isolation import await_isolation_loading, get_claimed_paths + from comfy.isolation.host_policy import load_host_policy + + # Load Global Host Policy + host_policy = load_host_policy(Path(folder_paths.base_path)) + whitelist_dict = host_policy.get("whitelist", {}) + # Normalize whitelist keys to lowercase for case-insensitive matching + # (matches ComfyUI-Manager's normalization: project.name.strip().lower()) + whitelist = set(k.strip().lower() for k in whitelist_dict.keys()) + logging.info(f"][ Loaded Whitelist: {len(whitelist)} nodes allowed.") + + isolated_specs = await await_isolation_loading() + for spec in isolated_specs: + NODE_CLASS_MAPPINGS.setdefault(spec.node_name, spec.stub_class) + NODE_DISPLAY_NAME_MAPPINGS.setdefault(spec.node_name, spec.display_name) + isolated_module_paths = get_claimed_paths() + base_node_names = set(NODE_CLASS_MAPPINGS.keys()) node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] @@ -2318,6 +2340,16 @@ async def init_external_custom_nodes(): logging.info(f"Blocked by policy: {module_path}") continue + if args.use_process_isolation: + if Path(module_path).resolve() in isolated_module_paths: + continue + + # Tri-State Enforcement: If not Isolated (checked above), MUST be Whitelisted. + # Normalize to lowercase for case-insensitive matching (matches ComfyUI-Manager) + if possible_module.strip().lower() not in whitelist: + logging.warning(f"][ REJECTED: Node '{possible_module}' is blocked by security policy (not whitelisted/isolated).") + continue + time_before = time.perf_counter() success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes") node_import_times.append((time.perf_counter() - time_before, module_path, success)) @@ -2332,6 +2364,14 @@ async def init_external_custom_nodes(): logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) logging.info("") + if args.use_process_isolation: + from comfy.isolation import isolated_node_timings + if isolated_node_timings: + logging.info("\nImport times for isolated custom nodes:") + for timing, path, count in sorted(isolated_node_timings): + logging.info("{:6.1f} seconds: {} ({})".format(timing, path, count)) + logging.info("") + async def init_builtin_extra_nodes(): """ Initializes the built-in extra nodes in ComfyUI. diff --git a/server.py b/server.py index 881da8e66..246a626a2 100644 --- a/server.py +++ b/server.py @@ -3,7 +3,6 @@ import sys import asyncio import traceback import time - import nodes import folder_paths import execution @@ -202,6 +201,8 @@ def create_block_external_middleware(): class PromptServer(): def __init__(self, loop): PromptServer.instance = self + if loop is None: + loop = asyncio.get_event_loop() self.user_manager = UserManager() self.model_file_manager = ModelFileManager() @@ -352,6 +353,17 @@ class PromptServer(): extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote( name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) + # Include JS files from proxied web directories (isolated nodes) + if args.use_process_isolation: + from comfy.isolation.proxies.web_directory_proxy import get_web_directory_cache + cache = get_web_directory_cache() + for ext_name in cache.extension_names: + for entry in cache.list_files(ext_name): + if entry["relative_path"].endswith(".js"): + extensions.append( + "/extensions/" + urllib.parse.quote(ext_name) + "/" + entry["relative_path"] + ) + return web.json_response(extensions) def get_dir_by_type(dir_type): @@ -1067,6 +1079,36 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) + # Add dynamic handler for proxied web directories (isolated nodes) + if args.use_process_isolation: + from comfy.isolation.proxies.web_directory_proxy import ( + get_web_directory_cache, + ALLOWED_EXTENSIONS, + ) + + async def proxied_web_handler(request): + ext_name = request.match_info["ext_name"] + file_path = request.match_info["file_path"] + + suffix = os.path.splitext(file_path)[1].lower() + if suffix not in ALLOWED_EXTENSIONS: + return web.Response(status=403, text="Forbidden file type") + + cache = get_web_directory_cache() + result = cache.get_file(ext_name, file_path) + if result is None: + return web.Response(status=404, text="Not found") + + return web.Response( + body=result["content"], + content_type=result["content_type"], + ) + + self.app.router.add_get( + "/extensions/{ext_name}/{file_path:.*}", + proxied_web_handler, + ) + installed_templates_version = FrontendManager.get_installed_templates_version() use_legacy_templates = True if installed_templates_version: