feat(isolation): execution and extension loading integration

This commit is contained in:
John Pollock 2026-04-30 14:51:08 -05:00
parent 359f972a1b
commit d1155318eb
6 changed files with 1799 additions and 48 deletions

View File

@ -0,0 +1,540 @@
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
from __future__ import annotations
import logging
import os
import inspect
import sys
import types
import platform
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple
import pyisolate
from pyisolate import ExtensionManager, ExtensionManagerConfig
from packaging.requirements import InvalidRequirement, Requirement
from packaging.utils import canonicalize_name
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
from .host_policy import load_host_policy
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
logger = logging.getLogger(__name__)
def _register_web_directory(extension_name: str, node_dir: Path) -> None:
"""Register an isolated extension's web directory on the host side."""
import nodes
# Method 1: pyproject.toml [tool.comfy] web field
pyproject = node_dir / "pyproject.toml"
if pyproject.exists():
try:
with pyproject.open("rb") as f:
data = tomllib.load(f)
web_dir_name = data.get("tool", {}).get("comfy", {}).get("web")
if web_dir_name:
web_dir_path = str(node_dir / web_dir_name)
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug(
"][ Registered web dir for isolated %s: %s",
extension_name,
web_dir_path,
)
return
except Exception:
pass
# Method 2: __init__.py WEB_DIRECTORY constant (parse without importing)
init_file = node_dir / "__init__.py"
if init_file.exists():
try:
source = init_file.read_text()
for line in source.splitlines():
stripped = line.strip()
if stripped.startswith("WEB_DIRECTORY"):
# Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web"
_, _, value = stripped.partition("=")
value = value.strip().strip("\"'")
if value:
web_dir_path = str((node_dir / value).resolve())
if os.path.isdir(web_dir_path):
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
logger.debug(
"][ Registered web dir for isolated %s: %s",
extension_name,
web_dir_path,
)
return
except Exception:
pass
def _get_extension_type(execution_model: str) -> type[Any]:
if execution_model == "sealed_worker":
return pyisolate.SealedNodeExtension
from .extension_wrapper import ComfyNodeExtension
return ComfyNodeExtension
async def _stop_extension_safe(extension: Any, extension_name: str) -> None:
try:
stop_result = extension.stop()
if inspect.isawaitable(stop_result):
await stop_result
except Exception:
logger.debug("][ %s stop failed", extension_name, exc_info=True)
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
req, sep, marker = dep.partition(";")
req = req.strip()
marker_suffix = f";{marker}" if sep else ""
def _resolve_local_path(local_path: str) -> Path | None:
for base in base_paths:
candidate = (base / local_path).resolve()
if candidate.exists():
return candidate
return None
if req.startswith("./") or req.startswith("../"):
resolved = _resolve_local_path(req)
if resolved is not None:
return f"{resolved}{marker_suffix}"
if req.startswith("file://"):
raw = req[len("file://") :]
if raw.startswith("./") or raw.startswith("../"):
resolved = _resolve_local_path(raw)
if resolved is not None:
return f"file://{resolved}{marker_suffix}"
return dep
def _dependency_name_from_spec(dep: str) -> str | None:
stripped = dep.strip()
if not stripped or stripped == "-e" or stripped.startswith("-e "):
return None
if stripped.startswith(("/", "./", "../", "file://")):
return None
try:
return canonicalize_name(Requirement(stripped).name)
except InvalidRequirement:
return None
def _parse_cuda_wheels_config(
tool_config: dict[str, object], dependencies: list[str]
) -> dict[str, object] | None:
raw_config = tool_config.get("cuda_wheels")
if raw_config is None:
return None
if not isinstance(raw_config, dict):
raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table")
index_url = raw_config.get("index_url")
index_urls = raw_config.get("index_urls")
if index_urls is not None:
if not isinstance(index_urls, list) or not all(
isinstance(u, str) and u.strip() for u in index_urls
):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings"
)
elif not isinstance(index_url, str) or not index_url.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
)
packages = raw_config.get("packages")
if not isinstance(packages, list) or not all(
isinstance(package_name, str) and package_name.strip()
for package_name in packages
):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings"
)
declared_dependencies = {
dependency_name
for dep in dependencies
if (dependency_name := _dependency_name_from_spec(dep)) is not None
}
normalized_packages = [canonicalize_name(package_name) for package_name in packages]
missing = [
package_name
for package_name in normalized_packages
if package_name not in declared_dependencies
]
if missing:
missing_joined = ", ".join(sorted(missing))
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: "
f"{missing_joined}"
)
package_map = raw_config.get("package_map", {})
if not isinstance(package_map, dict):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] must be a table"
)
normalized_package_map: dict[str, str] = {}
for dependency_name, index_package_name in package_map.items():
if not isinstance(dependency_name, str) or not dependency_name.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings"
)
if not isinstance(index_package_name, str) or not index_package_name.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings"
)
canonical_dependency_name = canonicalize_name(dependency_name)
if canonical_dependency_name not in normalized_packages:
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in "
"[tool.comfy.isolation.cuda_wheels.packages]"
)
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
result: dict = {
"packages": normalized_packages,
"package_map": normalized_package_map,
}
if index_urls is not None:
result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls]
else:
result["index_url"] = index_url.rstrip("/") + "/"
return result
def get_enforcement_policy() -> Dict[str, bool]:
return {
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
"force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
}
class ExtensionLoadError(RuntimeError):
pass
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
normalized_name = extension_name.replace("-", "_").replace(".", "_")
if normalized_name not in sys.modules:
dummy_module = types.ModuleType(normalized_name)
dummy_module.__file__ = str(node_dir / "__init__.py")
dummy_module.__path__ = [str(node_dir)]
dummy_module.__package__ = normalized_name
sys.modules[normalized_name] = dummy_module
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
for details in cached_data.values():
if not isinstance(details, dict):
return True
if details.get("is_v3") and "schema_v1" not in details:
return True
return False
async def load_isolated_node(
node_dir: Path,
manifest_path: Path,
logger: logging.Logger,
build_stub_class: Callable[[str, Dict[str, object], Any], type],
venv_root: Path,
extension_managers: List[ExtensionManager],
) -> List[Tuple[str, str, type]]:
try:
with manifest_path.open("rb") as handle:
manifest_data = tomllib.load(handle)
except Exception as e:
logger.warning(f"][ Failed to parse {manifest_path}: {e}")
return []
# Parse [tool.comfy.isolation]
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
can_isolate = tool_config.get("can_isolate", False)
share_torch = tool_config.get("share_torch", False)
package_manager = tool_config.get("package_manager", "uv")
is_conda = package_manager == "conda"
execution_model = tool_config.get("execution_model")
if execution_model is None:
execution_model = "sealed_worker" if is_conda else "host-coupled"
if "sealed_host_ro_paths" in tool_config:
raise ValueError(
"Manifest field 'sealed_host_ro_paths' is not allowed. "
"Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy."
)
# Conda-specific manifest fields
conda_channels: list[str] = (
tool_config.get("conda_channels", []) if is_conda else []
)
conda_dependencies: list[str] = (
tool_config.get("conda_dependencies", []) if is_conda else []
)
conda_platforms: list[str] = (
tool_config.get("conda_platforms", []) if is_conda else []
)
conda_python: str = (
tool_config.get("conda_python", "*") if is_conda else "*"
)
# Parse [project] dependencies
project_config = manifest_data.get("project", {})
dependencies = project_config.get("dependencies", [])
if not isinstance(dependencies, list):
dependencies = []
# Get extension name (default to folder name if not in project.name)
extension_name = project_config.get("name", node_dir.name)
# LOGIC: Isolation Decision
policy = get_enforcement_policy()
isolated = can_isolate or policy["force_isolated"]
if not isolated:
return []
import folder_paths
base_paths = [Path(folder_paths.base_path), node_dir]
dependencies = [
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
for dep in dependencies
]
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
extension_type = _get_extension_type(execution_model)
manager: ExtensionManager = pyisolate.ExtensionManager(
extension_type, manager_config
)
extension_managers.append(manager)
host_policy = load_host_policy(Path(folder_paths.base_path))
sandbox_config = {}
is_linux = platform.system() == "Linux"
if is_conda:
share_torch = False
share_cuda_ipc = False
else:
share_cuda_ipc = share_torch and is_linux
if is_linux and isolated:
sandbox_config = {
"network": host_policy["allow_network"],
"writable_paths": host_policy["writable_paths"],
"readonly_paths": host_policy["readonly_paths"],
}
extension_config: dict = {
"name": extension_name,
"module_path": str(node_dir),
"isolated": True,
"dependencies": dependencies,
"share_torch": share_torch,
"share_cuda_ipc": share_cuda_ipc,
"sandbox_mode": host_policy["sandbox_mode"],
"sandbox": sandbox_config,
}
share_torch_no_deps = tool_config.get("share_torch_no_deps", [])
if share_torch_no_deps:
if not isinstance(share_torch_no_deps, list) or not all(
isinstance(dep, str) and dep.strip() for dep in share_torch_no_deps
):
raise ExtensionLoadError(
"[tool.comfy.isolation.share_torch_no_deps] must be a list of non-empty strings"
)
extension_config["share_torch_no_deps"] = share_torch_no_deps
_is_sealed = execution_model == "sealed_worker"
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
logger.info(
"][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])",
extension_name,
"x" if share_torch else " ",
"x" if _is_sealed else " ",
"x" if _is_sandboxed else " ",
)
if cuda_wheels is not None:
extension_config["cuda_wheels"] = cuda_wheels
extra_index_urls = tool_config.get("extra_index_urls", [])
if extra_index_urls:
if not isinstance(extra_index_urls, list) or not all(
isinstance(u, str) and u.strip() for u in extra_index_urls
):
raise ExtensionLoadError(
"[tool.comfy.isolation.extra_index_urls] must be a list of non-empty strings"
)
extension_config["extra_index_urls"] = extra_index_urls
# Conda-specific keys
if is_conda:
extension_config["package_manager"] = "conda"
extension_config["conda_channels"] = conda_channels
extension_config["conda_dependencies"] = conda_dependencies
extension_config["conda_python"] = conda_python
find_links = tool_config.get("find_links", [])
if find_links:
extension_config["find_links"] = find_links
if conda_platforms:
extension_config["conda_platforms"] = conda_platforms
if execution_model != "host-coupled":
extension_config["execution_model"] = execution_model
if execution_model == "sealed_worker":
policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", [])
if isinstance(policy_ro_paths, list) and policy_ro_paths:
extension_config["sealed_host_ro_paths"] = list(policy_ro_paths)
# Sealed workers keep the host RPC service inventory even when the
# child resolves no API classes locally.
extension = manager.load_extension(extension_config)
register_dummy_module(extension_name, node_dir)
# Register host-side event handlers via adapter
from .adapter import ComfyUIAdapter
ComfyUIAdapter.register_host_event_handlers(extension)
# Register web directory on the host — only when sandbox is disabled.
# In sandbox mode, serving untrusted JS to the browser is not safe.
if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir)
# Register for proxied web serving — the child's web dir may have
# content that doesn't exist on the host (e.g., pip-installed viewer
# bundles). The WebDirectoryCache will lazily fetch via RPC.
from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache
class ChildWebDirectoryProxy:
def __init__(self, host_extension):
self._host_extension = host_extension
self._caller = None
def _get_caller(self):
self._host_extension.proxy
rpc = self._host_extension._extension.rpc
caller = rpc.create_caller(WebDirectoryProxy, WebDirectoryProxy.get_remote_id())
if self._caller is not caller:
self._caller = caller
return self._caller
def list_web_files(self, extension_name: str):
from .proxies.base import run_sync_rpc_coro
return run_sync_rpc_coro(self._get_caller().list_web_files(extension_name))
def get_web_file(self, extension_name: str, relative_path: str):
from .proxies.base import run_sync_rpc_coro
return run_sync_rpc_coro(
self._get_caller().get_web_file(extension_name, relative_path)
)
cache = get_web_directory_cache()
cache.register_proxy(extension_name, ChildWebDirectoryProxy(extension))
# Try cache first (lazy spawn)
if is_cache_valid(node_dir, manifest_path, venv_root):
cached_data = load_from_cache(node_dir, venv_root)
if cached_data:
if _is_stale_node_cache(cached_data):
pass
else:
try:
flushed = await extension.flush_pending_routes()
logger.info("][ %s flushed %d routes", extension_name, flushed)
except Exception as exc:
logger.warning("][ %s route flush failed: %s", extension_name, exc)
specs: List[Tuple[str, str, type]] = []
for node_name, details in cached_data.items():
stub_cls = build_stub_class(node_name, details, extension)
specs.append(
(node_name, details.get("display_name", node_name), stub_cls)
)
return specs
# Cache miss - spawn process and get metadata
try:
remote_nodes: Dict[str, str] = await extension.list_nodes()
except Exception as exc:
logger.warning(
"][ %s metadata discovery failed, skipping isolated load: %s",
extension_name,
exc,
)
await _stop_extension_safe(extension, extension_name)
return []
if not remote_nodes:
logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
await _stop_extension_safe(extension, extension_name)
return []
specs: List[Tuple[str, str, type]] = []
cache_data: Dict[str, Dict] = {}
for node_name, display_name in remote_nodes.items():
try:
details = await extension.get_node_details(node_name)
except Exception as exc:
logger.warning(
"][ %s failed to load metadata for %s, skipping node: %s",
extension_name,
node_name,
exc,
)
continue
details["display_name"] = display_name
cache_data[node_name] = details
stub_cls = build_stub_class(node_name, details, extension)
specs.append((node_name, display_name, stub_cls))
if not specs:
logger.warning(
"][ %s produced no usable nodes after metadata scan; skipping",
extension_name,
)
await _stop_extension_safe(extension, extension_name)
return []
# Save metadata to cache for future runs
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
logger.debug(f"][ {extension_name} metadata cached")
# Re-check web directory AFTER child has populated it
if host_policy["sandbox_mode"] == "disabled":
_register_web_directory(extension_name, node_dir)
# Flush any routes the child buffered during module import — must happen
# before router freeze and before we kill the child process.
try:
flushed = await extension.flush_pending_routes()
logger.info("][ %s flushed %d routes", extension_name, flushed)
except Exception as exc:
logger.warning("][ %s route flush failed: %s", extension_name, exc)
# EJECT: Kill process after getting metadata (will respawn on first execution)
await _stop_extension_safe(extension, extension_name)
return specs
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]

View File

@ -0,0 +1,942 @@
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
from __future__ import annotations
import asyncio
import torch
class AttrDict(dict):
def __getattr__(self, item):
try:
return self[item]
except KeyError as e:
raise AttributeError(item) from e
def copy(self):
return AttrDict(super().copy())
import importlib
import inspect
import json
import logging
import os
import sys
import uuid
from dataclasses import asdict
from typing import Any, Dict, List, Tuple
from pyisolate import ExtensionBase
from comfy_api.internal import _ComfyNodeInternal
LOG_PREFIX = "]["
V3_DISCOVERY_TIMEOUT = 30
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
logger = logging.getLogger(__name__)
def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None:
"""Run the web asset copy step that prestartup_script.py used to do.
If the module's web/ directory is empty and the module had a
prestartup_script.py that copied assets from pip packages, this
function replicates that work inside the child process.
Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if
defined, otherwise falls back to detecting common asset packages.
"""
import shutil
# Already populated — nothing to do
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
return
os.makedirs(web_dir_path, exist_ok=True)
# Try module-defined copy spec first (generic hook for any node pack)
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
if copy_spec is not None and callable(copy_spec):
try:
copy_spec(web_dir_path)
logger.info(
"%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir
)
return
except Exception as e:
logger.warning(
"%s _PRESTARTUP_WEB_COPY failed for %s: %s",
LOG_PREFIX, module_dir, e,
)
# Fallback: detect comfy_3d_viewers and run copy_viewer()
try:
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
viewers = list(VIEWER_FILES.keys())
for viewer in viewers:
try:
copy_viewer(viewer, web_dir_path)
except Exception:
pass
if any(os.scandir(web_dir_path)):
logger.info(
"%s Copied %d viewer types from comfy_3d_viewers to %s",
LOG_PREFIX, len(viewers), web_dir_path,
)
except ImportError:
pass
# Fallback: detect comfy_dynamic_widgets
try:
from comfy_dynamic_widgets import get_js_path
src = os.path.realpath(get_js_path())
if os.path.exists(src):
dst_dir = os.path.join(web_dir_path, "js")
os.makedirs(dst_dir, exist_ok=True)
dst = os.path.join(dst_dir, "dynamic_widgets.js")
shutil.copy2(src, dst)
except ImportError:
pass
def _read_extension_name(module_dir: str) -> str:
"""Read extension name from pyproject.toml, falling back to directory name."""
pyproject = os.path.join(module_dir, "pyproject.toml")
if os.path.exists(pyproject):
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
try:
with open(pyproject, "rb") as f:
data = tomllib.load(f)
name = data.get("project", {}).get("name")
if name:
return name
except Exception:
pass
return os.path.basename(module_dir)
def _flush_tensor_transport_state(marker: str) -> int:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
return 0
if not callable(flush_tensor_keeper):
return 0
flushed = flush_tensor_keeper()
if flushed > 0:
logger.debug(
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
)
return flushed
def _relieve_child_vram_pressure(marker: str) -> None:
import comfy.model_management as model_management
model_management.cleanup_models_gc()
model_management.cleanup_models()
device = model_management.get_torch_device()
if not hasattr(device, "type") or device.type == "cpu":
return
required = max(
model_management.minimum_inference_memory(),
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=True)
if model_management.get_free_memory(device) < required:
model_management.free_memory(required, device, for_dynamic=False)
model_management.cleanup_models()
model_management.soft_empty_cache()
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
def _sanitize_for_transport(value):
primitives = (str, int, float, bool, type(None))
if isinstance(value, primitives):
return value
cls_name = value.__class__.__name__
if cls_name == "FlexibleOptionalInputType":
return {
"__pyisolate_flexible_optional__": True,
"type": _sanitize_for_transport(getattr(value, "type", "*")),
}
if cls_name == "AnyType":
return {"__pyisolate_any_type__": True, "value": str(value)}
if cls_name == "ByPassTypeTuple":
return {
"__pyisolate_bypass_tuple__": [
_sanitize_for_transport(v) for v in tuple(value)
]
}
if isinstance(value, dict):
return {k: _sanitize_for_transport(v) for k, v in value.items()}
if isinstance(value, tuple):
return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
if isinstance(value, list):
return [_sanitize_for_transport(v) for v in value]
return str(value)
# Re-export RemoteObjectHandle from pyisolate for backward compatibility
# The canonical definition is now in pyisolate._internal.remote_handle
from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401
class ComfyNodeExtension(ExtensionBase):
def __init__(self) -> None:
super().__init__()
self.node_classes: Dict[str, type] = {}
self.display_names: Dict[str, str] = {}
self.node_instances: Dict[str, Any] = {}
self.remote_objects: Dict[str, Any] = {}
self._route_handlers: Dict[str, Any] = {}
self._module: Any = None
self._metadata_ready = asyncio.Event()
async def on_module_loaded(self, module: Any) -> None:
try:
self._module = module
# Registries are initialized in host_hooks.py initialize_host_process()
# They auto-register via ProxiedSingleton when instantiated
# NO additional setup required here - if a registry is missing from host_hooks, it WILL fail
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
self._register_module_routes(module)
# Register web directory with WebDirectoryProxy (child-side)
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
if web_dir_attr is not None:
module_dir = os.path.dirname(os.path.abspath(module.__file__))
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
ext_name = _read_extension_name(module_dir)
# If web dir is empty, run the copy step that prestartup_script.py did
_run_prestartup_web_copy(module, module_dir, web_dir_path)
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
try:
from comfy_api.latest import ComfyExtension
for name, obj in inspect.getmembers(module):
if not (
inspect.isclass(obj)
and issubclass(obj, ComfyExtension)
and obj is not ComfyExtension
):
continue
if not obj.__module__.startswith(module.__name__):
continue
try:
ext_instance = obj()
try:
await asyncio.wait_for(
ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
)
except asyncio.TimeoutError:
logger.error(
"%s V3 Extension %s timed out in on_load()",
LOG_PREFIX,
name,
)
continue
try:
v3_nodes = await asyncio.wait_for(
ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
)
except asyncio.TimeoutError:
logger.error(
"%s V3 Extension %s timed out in get_node_list()",
LOG_PREFIX,
name,
)
continue
for node_cls in v3_nodes:
if hasattr(node_cls, "GET_SCHEMA"):
schema = node_cls.GET_SCHEMA()
self.node_classes[schema.node_id] = node_cls
if schema.display_name:
self.display_names[schema.node_id] = schema.display_name
except Exception as e:
logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
except ImportError:
pass
module_name = getattr(module, "__name__", "isolated_nodes")
for node_cls in self.node_classes.values():
if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
node_cls.__module__ = module_name
self.node_instances = {}
finally:
self._metadata_ready.set()
def _register_module_routes(self, module: Any) -> None:
"""Bridge legacy module-level ROUTES declarations into isolated routing."""
routes = getattr(module, "ROUTES", None) or []
if not routes:
return
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
prompt_server = PromptServerStub()
route_table = getattr(prompt_server, "routes", None)
if route_table is None:
logger.warning("%s Route registration unavailable for %s", LOG_PREFIX, module)
return
for route_spec in routes:
if not isinstance(route_spec, dict):
logger.warning("%s Ignoring non-dict ROUTES entry: %r", LOG_PREFIX, route_spec)
continue
method = str(route_spec.get("method", "")).strip().upper()
path = str(route_spec.get("path", "")).strip()
handler_ref = route_spec.get("handler")
if not method or not path:
logger.warning("%s Ignoring incomplete route spec: %r", LOG_PREFIX, route_spec)
continue
if isinstance(handler_ref, str):
handler = getattr(module, handler_ref, None)
else:
handler = handler_ref
if not callable(handler):
logger.warning(
"%s Ignoring route with missing handler %r for %s %s",
LOG_PREFIX,
handler_ref,
method,
path,
)
continue
decorator = getattr(route_table, method.lower(), None)
if not callable(decorator):
logger.warning("%s Unsupported route method %s for %s", LOG_PREFIX, method, path)
continue
decorator(path)(handler)
self._route_handlers[f"{method} {path}"] = handler
logger.info("%s buffered legacy route %s %s", LOG_PREFIX, method, path)
async def list_nodes(self) -> Dict[str, str]:
await asyncio.wait_for(
self._metadata_ready.wait(), timeout=V3_DISCOVERY_TIMEOUT
)
return {name: self.display_names.get(name, name) for name in self.node_classes}
async def get_node_info(self, node_name: str) -> Dict[str, Any]:
return await self.get_node_details(node_name)
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
await asyncio.wait_for(
self._metadata_ready.wait(), timeout=V3_DISCOVERY_TIMEOUT
)
node_cls = self._get_node_class(node_name)
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
input_types_raw = (
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
)
output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
if output_is_list is not None:
output_is_list = tuple(bool(x) for x in output_is_list)
details: Dict[str, Any] = {
"input_types": _sanitize_for_transport(input_types_raw),
"return_types": tuple(
str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
),
"return_names": getattr(node_cls, "RETURN_NAMES", None),
"function": str(getattr(node_cls, "FUNCTION", "execute")),
"category": str(getattr(node_cls, "CATEGORY", "")),
"output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
"output_is_list": output_is_list,
"is_v3": is_v3,
}
if is_v3:
try:
schema = node_cls.GET_SCHEMA()
schema_v1 = asdict(schema.get_v1_info(node_cls))
try:
schema_v3 = asdict(schema.get_v3_info(node_cls))
except (AttributeError, TypeError):
schema_v3 = self._build_schema_v3_fallback(schema)
details.update(
{
"schema_v1": schema_v1,
"schema_v3": schema_v3,
"hidden": [h.value for h in (schema.hidden or [])],
"description": getattr(schema, "description", ""),
"deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
"experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
"api_node": bool(getattr(node_cls, "API_NODE", False)),
"input_is_list": bool(
getattr(node_cls, "INPUT_IS_LIST", False)
),
"not_idempotent": bool(
getattr(node_cls, "NOT_IDEMPOTENT", False)
),
"accept_all_inputs": bool(
getattr(node_cls, "ACCEPT_ALL_INPUTS", False)
),
}
)
except Exception as exc:
logger.warning(
"%s V3 schema serialization failed for %s: %s",
LOG_PREFIX,
node_name,
exc,
)
return details
def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
input_dict: Dict[str, Any] = {}
output_dict: Dict[str, Any] = {}
hidden_list: List[str] = []
if getattr(schema, "inputs", None):
for inp in schema.inputs:
self._add_schema_io_v3(inp, input_dict)
if getattr(schema, "outputs", None):
for out in schema.outputs:
self._add_schema_io_v3(out, output_dict)
if getattr(schema, "hidden", None):
for h in schema.hidden:
hidden_list.append(getattr(h, "value", str(h)))
return {
"input": input_dict,
"output": output_dict,
"hidden": hidden_list,
"name": getattr(schema, "node_id", None),
"display_name": getattr(schema, "display_name", None),
"description": getattr(schema, "description", None),
"category": getattr(schema, "category", None),
"output_node": getattr(schema, "is_output_node", False),
"deprecated": getattr(schema, "is_deprecated", False),
"experimental": getattr(schema, "is_experimental", False),
"api_node": getattr(schema, "is_api_node", False),
}
def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
io_id = getattr(io_obj, "id", None)
if io_id is None:
return
io_type_fn = getattr(io_obj, "get_io_type", None)
io_type = (
io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
)
as_dict_fn = getattr(io_obj, "as_dict", None)
payload = as_dict_fn() if callable(as_dict_fn) else {}
target[str(io_id)] = (io_type, payload)
async def get_input_types(self, node_name: str) -> Dict[str, Any]:
node_cls = self._get_node_class(node_name)
if hasattr(node_cls, "INPUT_TYPES"):
return node_cls.INPUT_TYPES()
return {}
async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
logger.debug(
"%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
len(inputs),
)
if os.environ.get("PYISOLATE_CHILD") == "1":
_relieve_child_vram_pressure("EXT:pre_execute")
resolved_inputs = self._resolve_remote_objects(inputs)
instance = self._get_node_instance(node_name)
node_cls = self._get_node_class(node_name)
# V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
# Hidden params come through RPC as string keys like "Hidden.prompt"
from comfy_api.latest._io import Hidden, HiddenHolder
# Map string representations back to Hidden enum keys
hidden_string_map = {
"Hidden.unique_id": Hidden.unique_id,
"Hidden.prompt": Hidden.prompt,
"Hidden.extra_pnginfo": Hidden.extra_pnginfo,
"Hidden.dynprompt": Hidden.dynprompt,
"Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
"Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
# Uppercase enum VALUE forms — V3 execution engine passes these
"UNIQUE_ID": Hidden.unique_id,
"PROMPT": Hidden.prompt,
"EXTRA_PNGINFO": Hidden.extra_pnginfo,
"DYNPROMPT": Hidden.dynprompt,
"AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org,
"API_KEY_COMFY_ORG": Hidden.api_key_comfy_org,
}
# Find and extract hidden parameters (both enum and string form)
hidden_found = {}
keys_to_remove = []
for key in list(resolved_inputs.keys()):
# Check string form first (from RPC serialization)
if key in hidden_string_map:
hidden_found[hidden_string_map[key]] = resolved_inputs[key]
keys_to_remove.append(key)
# Also check enum form (direct calls)
elif isinstance(key, Hidden):
hidden_found[key] = resolved_inputs[key]
keys_to_remove.append(key)
# Remove hidden params from kwargs
for key in keys_to_remove:
resolved_inputs.pop(key)
# Set hidden on node class if any hidden params found
if hidden_found:
if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
node_cls.hidden = HiddenHolder.from_dict(hidden_found)
else:
# Update existing hidden holder
for key, value in hidden_found.items():
setattr(node_cls.hidden, key.value.lower(), value)
# INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this
# flag is set. The isolation RPC delivers unwrapped values, so we must
# wrap each input in a single-element list to match the contract.
if getattr(node_cls, "INPUT_IS_LIST", False):
resolved_inputs = {k: [v] for k, v in resolved_inputs.items()}
function_name = getattr(node_cls, "FUNCTION", "execute")
if not hasattr(instance, function_name):
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
handler = getattr(instance, function_name)
try:
import torch
if asyncio.iscoroutinefunction(handler):
with torch.inference_mode():
result = await handler(**resolved_inputs)
else:
import functools
def _run_with_inference_mode(**kwargs):
with torch.inference_mode():
return handler(**kwargs)
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None, functools.partial(_run_with_inference_mode, **resolved_inputs)
)
except Exception:
logger.exception(
"%s ISO:child_execute_error ext=%s node=%s",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
)
raise
if type(result).__name__ == "NodeOutput":
node_output_dict = {
"__node_output__": True,
"args": self._wrap_unpicklable_objects(result.args),
}
if result.ui is not None:
node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui)
if getattr(result, "expand", None) is not None:
node_output_dict["expand"] = result.expand
if getattr(result, "block_execution", None) is not None:
node_output_dict["block_execution"] = result.block_execution
return node_output_dict
if self._is_comfy_protocol_return(result):
wrapped = self._wrap_unpicklable_objects(result)
return wrapped
if not isinstance(result, tuple):
result = (result,)
wrapped = self._wrap_unpicklable_objects(result)
return wrapped
async def flush_pending_routes(self) -> int:
"""Flush buffered route registrations to host via RPC. Called by host after node discovery."""
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
return await PromptServerStub.flush_child_routes()
async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_CHILD") != "1":
return 0
logger.debug(
"%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
)
flushed = _flush_tensor_transport_state("EXT:workflow_end")
try:
from comfy.isolation.model_patcher_proxy_registry import (
ModelPatcherRegistry,
)
registry = ModelPatcherRegistry()
removed = registry.sweep_pending_cleanup()
if removed > 0:
logger.debug(
"%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
)
except Exception:
logger.debug(
"%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
)
logger.debug(
"%s ISO:child_flush_done ext=%s flushed=%d",
LOG_PREFIX,
getattr(self, "name", "?"),
flushed,
)
return flushed
async def get_remote_object(self, object_id: str) -> Any:
"""Retrieve a remote object by ID for host-side deserialization."""
if object_id not in self.remote_objects:
raise KeyError(f"Remote object {object_id} not found")
return self.remote_objects[object_id]
def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle:
object_id = str(uuid.uuid4())
self.remote_objects[object_id] = obj
return RemoteObjectHandle(object_id, type(obj).__name__)
async def call_remote_object_method(
self,
object_id: str,
method_name: str,
*args: Any,
**kwargs: Any,
) -> Any:
"""Invoke a method or attribute-backed accessor on a child-owned object."""
obj = await self.get_remote_object(object_id)
if method_name == "get_patcher_attr":
return getattr(obj, args[0])
if method_name == "get_model_options":
return getattr(obj, "model_options")
if method_name == "set_model_options":
setattr(obj, "model_options", args[0])
return None
if method_name == "get_object_patches":
return getattr(obj, "object_patches")
if method_name == "get_patches":
return getattr(obj, "patches")
if method_name == "get_wrappers":
return getattr(obj, "wrappers")
if method_name == "get_callbacks":
return getattr(obj, "callbacks")
if method_name == "get_load_device":
return getattr(obj, "load_device")
if method_name == "get_offload_device":
return getattr(obj, "offload_device")
if method_name == "get_hook_mode":
return getattr(obj, "hook_mode")
if method_name == "get_parent":
parent = getattr(obj, "parent", None)
if parent is None:
return None
return self._store_remote_object_handle(parent)
if method_name == "get_inner_model_attr":
attr_name = args[0]
if hasattr(obj.model, attr_name):
return getattr(obj.model, attr_name)
if hasattr(obj, attr_name):
return getattr(obj, attr_name)
return None
if method_name == "inner_model_apply_model":
return obj.model.apply_model(*args[0], **args[1])
if method_name == "inner_model_extra_conds_shapes":
return obj.model.extra_conds_shapes(*args[0], **args[1])
if method_name == "inner_model_extra_conds":
return obj.model.extra_conds(*args[0], **args[1])
if method_name == "inner_model_memory_required":
return obj.model.memory_required(*args[0], **args[1])
if method_name == "process_latent_in":
return obj.model.process_latent_in(*args[0], **args[1])
if method_name == "process_latent_out":
return obj.model.process_latent_out(*args[0], **args[1])
if method_name == "scale_latent_inpaint":
return obj.model.scale_latent_inpaint(*args[0], **args[1])
if method_name.startswith("get_"):
attr_name = method_name[4:]
if hasattr(obj, attr_name):
return getattr(obj, attr_name)
target = getattr(obj, method_name)
if callable(target):
result = target(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
if type(result).__name__ == "ModelPatcher":
return self._store_remote_object_handle(result)
return result
if args or kwargs:
raise TypeError(f"{method_name} is not callable on remote object {object_id}")
return target
def _wrap_unpicklable_objects(self, data: Any) -> Any:
if isinstance(data, (str, int, float, bool, type(None))):
return data
if isinstance(data, torch.Tensor):
tensor = data.detach() if data.requires_grad else data
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
return tensor.cpu()
return tensor
# Special-case clip vision outputs: preserve attribute access by packing fields
if hasattr(data, "penultimate_hidden_states") or hasattr(
data, "last_hidden_state"
):
fields = {}
for attr in (
"penultimate_hidden_states",
"last_hidden_state",
"image_embeds",
"text_embeds",
):
if hasattr(data, attr):
try:
fields[attr] = self._wrap_unpicklable_objects(
getattr(data, attr)
)
except Exception:
pass
if fields:
return {"__pyisolate_attribute_container__": True, "data": fields}
# Avoid converting arbitrary objects with stateful methods (models, etc.)
# They will be handled via RemoteObjectHandle below.
type_name = type(data).__name__
if type_name == "ModelPatcherProxy":
return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
if type_name == "CLIPProxy":
return {"__type__": "CLIPRef", "clip_id": data._instance_id}
if type_name == "VAEProxy":
return {"__type__": "VAERef", "vae_id": data._instance_id}
if type_name == "ModelSamplingProxy":
return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}
if isinstance(data, (list, tuple)):
wrapped = [self._wrap_unpicklable_objects(item) for item in data]
return tuple(wrapped) if isinstance(data, tuple) else wrapped
if isinstance(data, dict):
converted_dict = {
k: self._wrap_unpicklable_objects(v) for k, v in data.items()
}
return {"__pyisolate_attrdict__": True, "data": converted_dict}
from pyisolate._internal.serialization_registry import SerializerRegistry
registry = SerializerRegistry.get_instance()
if registry.is_data_type(type_name):
serializer = registry.get_serializer(type_name)
if serializer:
return serializer(data)
return self._store_remote_object_handle(data)
def _resolve_remote_objects(self, data: Any) -> Any:
if isinstance(data, RemoteObjectHandle):
if data.object_id not in self.remote_objects:
raise KeyError(f"Remote object {data.object_id} not found")
return self.remote_objects[data.object_id]
if isinstance(data, dict):
ref_type = data.get("__type__")
if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
from pyisolate._internal.model_serialization import (
deserialize_proxy_result,
)
return deserialize_proxy_result(data)
if ref_type == "ModelSamplingRef":
from pyisolate._internal.model_serialization import (
deserialize_proxy_result,
)
return deserialize_proxy_result(data)
return {k: self._resolve_remote_objects(v) for k, v in data.items()}
if isinstance(data, (list, tuple)):
resolved = [self._resolve_remote_objects(item) for item in data]
return tuple(resolved) if isinstance(data, tuple) else resolved
return data
def _get_node_class(self, node_name: str) -> type:
if node_name not in self.node_classes:
raise KeyError(f"Unknown node: {node_name}")
return self.node_classes[node_name]
def _get_node_instance(self, node_name: str) -> Any:
if node_name not in self.node_instances:
if node_name not in self.node_classes:
raise KeyError(f"Unknown node: {node_name}")
self.node_instances[node_name] = self.node_classes[node_name]()
return self.node_instances[node_name]
async def before_module_loaded(self) -> None:
try:
from comfy.isolation import initialize_proxies
initialize_proxies()
except Exception as e:
logger.error(
"%s before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e
)
await super().before_module_loaded()
try:
from comfy_api.latest import ComfyAPI_latest
from .proxies.progress_proxy import ProgressProxy
ComfyAPI_latest.Execution = ProgressProxy
# ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision
# fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision
# latest_ui.folder_paths = fp_proxy
# latest_resources.folder_paths = fp_proxy
except Exception:
pass
async def call_route_handler(
self,
handler_module: str,
handler_func: str,
request_data: Dict[str, Any],
) -> Any:
cache_key = f"{handler_module}.{handler_func}"
if cache_key not in self._route_handlers:
if self._module is not None and hasattr(self._module, "__file__"):
node_dir = os.path.dirname(self._module.__file__)
if node_dir not in sys.path:
sys.path.insert(0, node_dir)
try:
module = importlib.import_module(handler_module)
self._route_handlers[cache_key] = getattr(module, handler_func)
except (ImportError, AttributeError) as e:
raise ValueError(f"Route handler not found: {cache_key}") from e
handler = self._route_handlers[cache_key]
mock_request = MockRequest(request_data)
if asyncio.iscoroutinefunction(handler):
result = await handler(mock_request)
else:
result = handler(mock_request)
return self._serialize_response(result)
def _is_comfy_protocol_return(self, result: Any) -> bool:
"""
Check if the result matches the ComfyUI 'Protocol Return' schema.
A Protocol Return is a dictionary containing specific reserved keys that
ComfyUI's execution engine interprets as instructions (UI updates,
Workflow expansion, etc.) rather than purely data outputs.
Schema:
- Must be a dict
- Must contain at least one of: 'ui', 'result', 'expand'
"""
if not isinstance(result, dict):
return False
return any(key in result for key in ("ui", "result", "expand"))
def _serialize_response(self, response: Any) -> Dict[str, Any]:
if response is None:
return {"type": "text", "body": "", "status": 204}
if isinstance(response, dict):
return {"type": "json", "body": response, "status": 200}
if isinstance(response, str):
return {"type": "text", "body": response, "status": 200}
if hasattr(response, "text") and hasattr(response, "status"):
return {
"type": "text",
"body": response.text
if hasattr(response, "text")
else str(response.body),
"status": response.status,
"headers": dict(response.headers)
if hasattr(response, "headers")
else {},
}
if hasattr(response, "body") and hasattr(response, "status"):
body = response.body
if isinstance(body, bytes):
try:
return {
"type": "text",
"body": body.decode("utf-8"),
"status": response.status,
}
except UnicodeDecodeError:
return {
"type": "binary",
"body": body.hex(),
"status": response.status,
}
return {"type": "json", "body": body, "status": response.status}
return {"type": "text", "body": str(response), "status": 200}
class MockRequest:
def __init__(self, data: Dict[str, Any]):
self.method = data.get("method", "GET")
self.path = data.get("path", "/")
self.query = data.get("query", {})
self._body = data.get("body", {})
self._text = data.get("text", "")
self.headers = data.get("headers", {})
self.content_type = data.get(
"content_type", self.headers.get("Content-Type", "application/json")
)
self.match_info = data.get("match_info", {})
async def json(self) -> Any:
if isinstance(self._body, dict):
return self._body
if isinstance(self._body, str):
return json.loads(self._body)
return {}
async def post(self) -> Dict[str, Any]:
if isinstance(self._body, dict):
return self._body
return {}
async def text(self) -> str:
if self._text:
return self._text
if isinstance(self._body, str):
return self._body
if isinstance(self._body, dict):
return json.dumps(self._body)
return ""
async def read(self) -> bytes:
return (await self.text()).encode("utf-8")

View File

@ -1,7 +1,9 @@
import copy
import gc
import heapq
import inspect
import logging
import os
import sys
import threading
import time
@ -42,6 +44,8 @@ from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_re
from comfy_api.latest import io, _io
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = False
class ExecutionResult(Enum):
SUCCESS = 0
@ -262,20 +266,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
pre_execute_cb(index)
# V3
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
# if is just a class, then assign no state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
# Check for isolated node - skip validation and class cloning
if hasattr(obj, "_pyisolate_extension"):
# Isolated Node: The stub is just a proxy; real validation happens in child process
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# Inject hidden inputs so they're available in the isolated child process
inputs.update(v3_data.get("hidden_inputs", {}))
f = getattr(obj, func)
# Standard V3 Node (Existing Logic)
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
# in case of dynamic inputs, restructure inputs to expected nested dict
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# if is just a class, then assign no resources or state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
# in case of dynamic inputs, restructure inputs to expected nested dict
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# V1
else:
f = getattr(obj, func)
@ -537,7 +552,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if args.verbose == "DEBUG":
comfy_aimdo.control.analyze()
comfy.model_management.reset_cast_buffers()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
vbar_lib = getattr(comfy_aimdo.model_vbar, "lib", None)
if vbar_lib is not None:
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
else:
global _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED
if not _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED:
logging.warning(
"DynamicVRAM backend unavailable for watermark reset; "
"skipping vbar reset for this process."
)
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = True
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
@ -546,8 +571,29 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
await asyncio.gather(*tasks, return_exceptions=True)
unblock()
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
# Keep isolation node execution deterministic by default, but allow
# opt-out for diagnostics.
isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes")
if args.use_process_isolation and isolation_sequential:
await await_completion()
results = []
for r in pending_async_nodes[unique_id]:
if isinstance(r, asyncio.Task):
try:
results.append(r.result())
except Exception as ex:
del pending_async_nodes[unique_id]
raise ex
else:
results.append(r)
del pending_async_nodes[unique_id]
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
has_pending_tasks = False
else:
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
ui_outputs[unique_id] = {
"meta": {
@ -657,6 +703,46 @@ class PromptExecutor:
self.status_messages = []
self.success = True
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
if not args.use_process_isolation:
return
try:
from comfy.isolation import notify_execution_graph
await notify_execution_graph(class_types, caches=self.caches.all)
except Exception:
if fail_loud:
raise
logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
async def _flush_running_extensions_transport_state_safe(self) -> None:
if not args.use_process_isolation:
return
try:
from comfy.isolation import flush_running_extensions_transport_state
await flush_running_extensions_transport_state()
except Exception:
logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
async def _wait_model_patcher_quiescence_safe(
self,
*,
fail_loud: bool = False,
timeout_ms: int = 120000,
marker: str = "EX:wait_model_patcher_idle",
) -> None:
if not args.use_process_isolation:
return
try:
from comfy.isolation import wait_for_model_patcher_quiescence
await wait_for_model_patcher_quiescence(
timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker
)
except Exception:
if fail_loud:
raise
logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True)
def add_message(self, event, data: dict, broadcast: bool):
data = {
**data,
@ -711,6 +797,18 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
if args.use_process_isolation:
# Update RPC event loops for all isolated extensions.
# This is critical for serial workflow execution - each asyncio.run() creates
# a new event loop, and RPC instances must be updated to use it.
try:
from comfy.isolation import update_rpc_event_loops
update_rpc_event_loops()
except ImportError:
pass # Isolation not available
except Exception as e:
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
set_preview_method(extra_data.get("preview_method"))
nodes.interrupt_processing(False)
@ -723,6 +821,25 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
if args.use_process_isolation:
try:
# Boundary cleanup runs at the start of the next workflow in
# isolation mode, matching non-isolated "next prompt" timing.
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
await self._wait_model_patcher_quiescence_safe(
fail_loud=False,
timeout_ms=120000,
marker="EX:boundary_cleanup_wait_idle",
)
await self._flush_running_extensions_transport_state_safe()
comfy.model_management.unload_all_models()
comfy.model_management.cleanup_models_gc()
comfy.model_management.cleanup_models()
gc.collect()
comfy.model_management.soft_empty_cache()
except Exception:
logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True)
self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
@ -760,6 +877,18 @@ class PromptExecutor:
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
if args.use_process_isolation:
pending_class_types = set()
for node_id in execution_list.pendingNodes.keys():
class_type = dynamic_prompt.get_node(node_id)["class_type"]
pending_class_types.add(class_type)
await self._wait_model_patcher_quiescence_safe(
fail_loud=True,
timeout_ms=120000,
marker="EX:notify_graph_wait_idle",
)
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:

120
main.py
View File

@ -1,7 +1,27 @@
import os
import sys
IS_PYISOLATE_CHILD = os.environ.get("PYISOLATE_CHILD") == "1"
if __name__ == "__main__" and IS_PYISOLATE_CHILD:
del os.environ["PYISOLATE_CHILD"]
IS_PYISOLATE_CHILD = False
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
if CURRENT_DIR not in sys.path:
sys.path.insert(0, CURRENT_DIR)
if not IS_PYISOLATE_CHILD:
python_scripts_dir = os.path.dirname(os.path.realpath(sys.executable))
path_entries = os.environ.get("PATH", "").split(os.pathsep)
if python_scripts_dir and python_scripts_dir not in path_entries:
os.environ["PATH"] = os.pathsep.join([python_scripts_dir, *path_entries])
IS_PRIMARY_PROCESS = (not IS_PYISOLATE_CHILD) and __name__ == "__main__"
import comfy.options
comfy.options.enable_args_parsing()
import os
import importlib.util
import shutil
import importlib.metadata
@ -9,26 +29,56 @@ import folder_paths
import time
from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
from app.assets.seeder import asset_seeder
from app.assets.services import register_output_files
import itertools
import utils.extra_config
import utils.extra_config # noqa: F401
from utils.mime_types import init_mime_types
import faulthandler
import logging
import sys
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
from app.database.db import init_db, dependencies_available
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
import comfy_aimdo.control
if enables_dynamic_vram():
if not comfy_aimdo.control.init():
logging.warning(
"DynamicVRAM requested, but comfy-aimdo failed to initialize early. "
"Will fall back to legacy model loading if device init fails."
)
if args.use_process_isolation:
from comfy.isolation import initialize_proxies
initialize_proxies()
# Explicitly register the ComfyUI adapter for pyisolate (v1.0 architecture)
try:
import pyisolate
from comfy.isolation.adapter import ComfyUIAdapter
pyisolate.register_adapter(ComfyUIAdapter())
logging.info("PyIsolate adapter registered: comfyui")
except ImportError:
logging.warning("PyIsolate not installed or version too old for explicit registration")
except Exception as e:
logging.error(f"Failed to register PyIsolate adapter: {e}")
if not IS_PYISOLATE_CHILD:
if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:native'
if not IS_PYISOLATE_CHILD:
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
if IS_PRIMARY_PROCESS:
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1'
if not IS_PYISOLATE_CHILD:
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
faulthandler.enable(file=sys.stderr, all_threads=False)
import comfy_aimdo.control
@ -93,14 +143,15 @@ if args.enable_manager:
def apply_custom_paths():
from utils import extra_config # Deferred import - spawn re-runs main.py
# extra model paths
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
utils.extra_config.load_extra_path_config(config_path)
extra_config.load_extra_path_config(config_path)
# --output-directory, --input-directory, --user-directory
if args.output_directory:
@ -173,15 +224,17 @@ def execute_prestartup_script():
else:
import_message = " (PRESTARTUP FAILED)"
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
logging.info("")
logging.info("")
apply_custom_paths()
init_mime_types()
if not IS_PYISOLATE_CHILD:
apply_custom_paths()
init_mime_types()
if args.enable_manager:
if args.enable_manager and not IS_PYISOLATE_CHILD:
comfyui_manager.prestartup()
execute_prestartup_script()
if not IS_PYISOLATE_CHILD:
execute_prestartup_script()
# Main code
@ -192,17 +245,17 @@ import gc
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import comfy.utils
import execution
import server
from protocol import BinaryEventTypes
import nodes
import comfy.model_management
import comfyui_version
import app.logger
import hook_breaker_ac10a0
if not IS_PYISOLATE_CHILD:
import execution
import server
from protocol import BinaryEventTypes
import nodes
import comfy.model_management
import comfyui_version
import app.logger
import hook_breaker_ac10a0
import comfy.memory_management
import comfy.model_patcher
@ -462,6 +515,10 @@ def start_comfyui(asyncio_loop=None):
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
if args.use_process_isolation:
from comfy.isolation import start_isolation_loading_early
start_isolation_loading_early(asyncio_loop)
if args.enable_manager and not args.disable_manager_ui:
comfyui_manager.start()
@ -506,12 +563,13 @@ def start_comfyui(asyncio_loop=None):
if __name__ == "__main__":
# Running directly, just start ComfyUI.
logging.info("Python version: {}".format(sys.version))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
for package in ("comfy-aimdo", "comfy-kitchen"):
try:
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
except:
pass
if not IS_PYISOLATE_CHILD:
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
for package in ("comfy-aimdo", "comfy-kitchen"):
try:
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
except:
pass
if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")

View File

@ -1912,6 +1912,7 @@ class ImageInvert:
class ImageBatch:
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
ESSENTIALS_CATEGORY = "Image Tools"
@classmethod
def INPUT_TYPES(s):
@ -2295,6 +2296,27 @@ async def init_external_custom_nodes():
Returns:
None
"""
whitelist = set()
isolated_module_paths = set()
if args.use_process_isolation:
from pathlib import Path
from comfy.isolation import await_isolation_loading, get_claimed_paths
from comfy.isolation.host_policy import load_host_policy
# Load Global Host Policy
host_policy = load_host_policy(Path(folder_paths.base_path))
whitelist_dict = host_policy.get("whitelist", {})
# Normalize whitelist keys to lowercase for case-insensitive matching
# (matches ComfyUI-Manager's normalization: project.name.strip().lower())
whitelist = set(k.strip().lower() for k in whitelist_dict.keys())
logging.info(f"][ Loaded Whitelist: {len(whitelist)} nodes allowed.")
isolated_specs = await await_isolation_loading()
for spec in isolated_specs:
NODE_CLASS_MAPPINGS.setdefault(spec.node_name, spec.stub_class)
NODE_DISPLAY_NAME_MAPPINGS.setdefault(spec.node_name, spec.display_name)
isolated_module_paths = get_claimed_paths()
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
@ -2318,6 +2340,16 @@ async def init_external_custom_nodes():
logging.info(f"Blocked by policy: {module_path}")
continue
if args.use_process_isolation:
if Path(module_path).resolve() in isolated_module_paths:
continue
# Tri-State Enforcement: If not Isolated (checked above), MUST be Whitelisted.
# Normalize to lowercase for case-insensitive matching (matches ComfyUI-Manager)
if possible_module.strip().lower() not in whitelist:
logging.warning(f"][ REJECTED: Node '{possible_module}' is blocked by security policy (not whitelisted/isolated).")
continue
time_before = time.perf_counter()
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
node_import_times.append((time.perf_counter() - time_before, module_path, success))
@ -2332,6 +2364,14 @@ async def init_external_custom_nodes():
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
logging.info("")
if args.use_process_isolation:
from comfy.isolation import isolated_node_timings
if isolated_node_timings:
logging.info("\nImport times for isolated custom nodes:")
for timing, path, count in sorted(isolated_node_timings):
logging.info("{:6.1f} seconds: {} ({})".format(timing, path, count))
logging.info("")
async def init_builtin_extra_nodes():
"""
Initializes the built-in extra nodes in ComfyUI.

View File

@ -3,7 +3,6 @@ import sys
import asyncio
import traceback
import time
import nodes
import folder_paths
import execution
@ -202,6 +201,8 @@ def create_block_external_middleware():
class PromptServer():
def __init__(self, loop):
PromptServer.instance = self
if loop is None:
loop = asyncio.get_event_loop()
self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
@ -352,6 +353,17 @@ class PromptServer():
extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
# Include JS files from proxied web directories (isolated nodes)
if args.use_process_isolation:
from comfy.isolation.proxies.web_directory_proxy import get_web_directory_cache
cache = get_web_directory_cache()
for ext_name in cache.extension_names:
for entry in cache.list_files(ext_name):
if entry["relative_path"].endswith(".js"):
extensions.append(
"/extensions/" + urllib.parse.quote(ext_name) + "/" + entry["relative_path"]
)
return web.json_response(extensions)
def get_dir_by_type(dir_type):
@ -1067,6 +1079,36 @@ class PromptServer():
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([web.static('/extensions/' + name, dir)])
# Add dynamic handler for proxied web directories (isolated nodes)
if args.use_process_isolation:
from comfy.isolation.proxies.web_directory_proxy import (
get_web_directory_cache,
ALLOWED_EXTENSIONS,
)
async def proxied_web_handler(request):
ext_name = request.match_info["ext_name"]
file_path = request.match_info["file_path"]
suffix = os.path.splitext(file_path)[1].lower()
if suffix not in ALLOWED_EXTENSIONS:
return web.Response(status=403, text="Forbidden file type")
cache = get_web_directory_cache()
result = cache.get_file(ext_name, file_path)
if result is None:
return web.Response(status=404, text="Not found")
return web.Response(
body=result["content"],
content_type=result["content_type"],
)
self.app.router.add_get(
"/extensions/{ext_name}/{file_path:.*}",
proxied_web_handler,
)
installed_templates_version = FrontendManager.get_installed_templates_version()
use_legacy_templates = True
if installed_templates_version: