From d6a72e6eff7a523c5fd88570c5d207e1c80a913b Mon Sep 17 00:00:00 2001 From: John Pollock Date: Wed, 29 Apr 2026 00:22:46 -0500 Subject: [PATCH] feat(isolation): core infrastructure and pyisolate plumbing --- .gitignore | 1 + comfy/cli_args.py | 2 + comfy/isolation/__init__.py | 436 +++++++++++++++ comfy/isolation/adapter.py | 864 +++++++++++++++++++++++++++++ comfy/isolation/child_hooks.py | 122 ++++ comfy/isolation/host_hooks.py | 25 + comfy/isolation/manifest_loader.py | 221 ++++++++ comfy/isolation/rpc_bridge.py | 49 ++ comfy/isolation/runtime_helpers.py | 471 ++++++++++++++++ comfy/isolation/shm_forensics.py | 217 ++++++++ pyproject.toml | 11 + requirements.txt | 2 + 12 files changed, 2421 insertions(+) create mode 100644 comfy/isolation/__init__.py create mode 100644 comfy/isolation/adapter.py create mode 100644 comfy/isolation/child_hooks.py create mode 100644 comfy/isolation/host_hooks.py create mode 100644 comfy/isolation/manifest_loader.py create mode 100644 comfy/isolation/rpc_bridge.py create mode 100644 comfy/isolation/runtime_helpers.py create mode 100644 comfy/isolation/shm_forensics.py diff --git a/.gitignore b/.gitignore index 0ab4ba75e..0031414ed 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ web_custom_versions/ .DS_Store filtered-openapi.yaml uv.lock +.pyisolate_venvs/ diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dbaadf723..3d7963879 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -184,6 +184,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") +parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyproject.toml manifests containing a [tool.comfy.isolation] section.") + parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") diff --git a/comfy/isolation/__init__.py b/comfy/isolation/__init__.py new file mode 100644 index 000000000..640092f45 --- /dev/null +++ b/comfy/isolation/__init__.py @@ -0,0 +1,436 @@ +# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation +from __future__ import annotations +import asyncio +import inspect +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Set, TYPE_CHECKING +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +load_isolated_node = None +find_manifest_directories = None +build_stub_class = None +get_class_types_for_extension = None +scan_shm_forensics = None +start_shm_forensics = None + +if _IMPORT_TORCH: + import folder_paths + from .extension_loader import load_isolated_node + from .manifest_loader import find_manifest_directories + from .runtime_helpers import build_stub_class, get_class_types_for_extension + from .shm_forensics import scan_shm_forensics, start_shm_forensics + +if TYPE_CHECKING: + from pyisolate import ExtensionManager + from .extension_wrapper import ComfyNodeExtension + +LOG_PREFIX = "][" +isolated_node_timings: List[tuple[float, Path, int]] = [] + +if _IMPORT_TORCH: + PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs" + PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True) + +logger = logging.getLogger(__name__) +_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 +_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000 + + +def initialize_proxies() -> None: + from .child_hooks import is_child_process + + is_child = is_child_process() + + if is_child: + from .child_hooks import initialize_child_process + + initialize_child_process() + else: + from .host_hooks import initialize_host_process + + initialize_host_process() + if start_shm_forensics is not None: + start_shm_forensics() + + +@dataclass(frozen=True) +class IsolatedNodeSpec: + node_name: str + display_name: str + stub_class: type + module_path: Path + + +_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = [] +_CLAIMED_PATHS: Set[Path] = set() +_ISOLATION_SCAN_ATTEMPTED = False +_EXTENSION_MANAGERS: List["ExtensionManager"] = [] +_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {} +_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None +_EARLY_START_TIME: Optional[float] = None + + +def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + return + _EARLY_START_TIME = time.perf_counter() + _ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes()) + + +async def await_isolation_loading() -> List[IsolatedNodeSpec]: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + specs = await _ISOLATION_BACKGROUND_TASK + return specs + return await initialize_isolation_nodes() + + +async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]: + global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS + + if _ISOLATED_NODE_SPECS: + return _ISOLATED_NODE_SPECS + + if _ISOLATION_SCAN_ATTEMPTED: + return [] + + _ISOLATION_SCAN_ATTEMPTED = True + if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None: + return [] + manifest_entries = find_manifest_directories() + _CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries} + + if not manifest_entries: + return [] + + os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1" + concurrency_limit = max(1, (os.cpu_count() or 4) // 2) + semaphore = asyncio.Semaphore(concurrency_limit) + + async def load_with_semaphore( + node_dir: Path, manifest: Path + ) -> List[IsolatedNodeSpec]: + async with semaphore: + load_start = time.perf_counter() + spec_list = await load_isolated_node( + node_dir, + manifest, + logger, + lambda name, info, extension: build_stub_class( + name, + info, + extension, + _RUNNING_EXTENSIONS, + logger, + ), + PYISOLATE_VENV_ROOT, + _EXTENSION_MANAGERS, + ) + spec_list = [ + IsolatedNodeSpec( + node_name=node_name, + display_name=display_name, + stub_class=stub_cls, + module_path=node_dir, + ) + for node_name, display_name, stub_cls in spec_list + ] + isolated_node_timings.append( + (time.perf_counter() - load_start, node_dir, len(spec_list)) + ) + return spec_list + + tasks = [ + load_with_semaphore(node_dir, manifest) + for node_dir, manifest in manifest_entries + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + specs: List[IsolatedNodeSpec] = [] + for result in results: + if isinstance(result, Exception): + logger.error( + "%s Isolated node failed during startup; continuing: %s", + LOG_PREFIX, + result, + ) + continue + specs.extend(result) + + _ISOLATED_NODE_SPECS = specs + return list(_ISOLATED_NODE_SPECS) + + +def _get_class_types_for_extension(extension_name: str) -> Set[str]: + """Get all node class types (node names) belonging to an extension.""" + extension = _RUNNING_EXTENSIONS.get(extension_name) + if not extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in _ISOLATED_NODE_SPECS: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + + return class_types + + +async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None: + """Evict running extensions not needed for current execution. + + When *caches* is provided, cache entries for evicted extensions' node + class_types are invalidated to prevent stale ``RemoteObjectHandle`` + references from surviving in the output cache. + """ + await wait_for_model_patcher_quiescence( + timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS, + fail_loud=True, + marker="ISO:notify_graph_wait_idle", + ) + + evicted_class_types: Set[str] = set() + + async def _stop_extension( + ext_name: str, extension: "ComfyNodeExtension", reason: str + ) -> None: + # Collect class_types BEFORE stopping so we can invalidate cache entries. + ext_class_types = _get_class_types_for_extension(ext_name) + evicted_class_types.update(ext_class_types) + logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason) + logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name) + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + _RUNNING_EXTENSIONS.pop(ext_name, None) + logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name) + if scan_shm_forensics is not None: + scan_shm_forensics("ISO:stop_extension", refresh_model_context=True) + + if scan_shm_forensics is not None: + scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True) + isolated_class_types_in_graph = needed_class_types.intersection( + {spec.node_name for spec in _ISOLATED_NODE_SPECS} + ) + graph_uses_isolation = bool(isolated_class_types_in_graph) + logger.debug( + "%s ISO:notify_graph_start running=%d needed=%d", + LOG_PREFIX, + len(_RUNNING_EXTENSIONS), + len(needed_class_types), + ) + if graph_uses_isolation: + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + ext_class_types = _get_class_types_for_extension(ext_name) + + # If NONE of this extension's nodes are in the execution graph -> evict. + if not ext_class_types.intersection(needed_class_types): + await _stop_extension( + ext_name, + extension, + "isolated custom_node not in execution graph, evicting", + ) + else: + logger.debug( + "%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph", + LOG_PREFIX, + len(_RUNNING_EXTENSIONS), + ) + + # Isolated child processes add steady VRAM pressure; reclaim host-side models + # at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom. + try: + import comfy.model_management as model_management + + device = model_management.get_torch_device() + if getattr(device, "type", None) == "cuda": + required = max( + model_management.minimum_inference_memory(), + _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES, + ) + free_before = model_management.get_free_memory(device) + if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation: + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + await _stop_extension( + ext_name, + extension, + f"boundary low-vram restart (free={int(free_before)} target={int(required)})", + ) + if model_management.get_free_memory(device) < required: + model_management.unload_all_models() + model_management.cleanup_models_gc() + model_management.cleanup_models() + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.soft_empty_cache() + except Exception: + logger.debug( + "%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True + ) + finally: + # Invalidate cached outputs for evicted extensions so stale + # RemoteObjectHandle references are not served from cache. + if evicted_class_types and caches: + total_invalidated = 0 + for cache in caches: + if hasattr(cache, "invalidate_by_class_types"): + total_invalidated += cache.invalidate_by_class_types( + evicted_class_types + ) + if total_invalidated > 0: + logger.info( + "%s ISO:cache_invalidated count=%d class_types=%s", + LOG_PREFIX, + total_invalidated, + evicted_class_types, + ) + scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True) + logger.debug( + "%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS) + ) + + +async def flush_running_extensions_transport_state() -> int: + await wait_for_model_patcher_quiescence( + timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS, + fail_loud=True, + marker="ISO:flush_transport_wait_idle", + ) + total_flushed = 0 + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + flush_fn = getattr(extension, "flush_transport_state", None) + if not callable(flush_fn): + continue + try: + flushed = await flush_fn() + if isinstance(flushed, int): + total_flushed += flushed + if flushed > 0: + logger.debug( + "%s %s workflow-end flush released=%d", + LOG_PREFIX, + ext_name, + flushed, + ) + except Exception: + logger.debug( + "%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True + ) + scan_shm_forensics( + "ISO:flush_running_extensions_transport_state", refresh_model_context=True + ) + return total_flushed + + +async def wait_for_model_patcher_quiescence( + timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS, + *, + fail_loud: bool = False, + marker: str = "ISO:wait_model_patcher_idle", +) -> bool: + try: + from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry + + registry = ModelPatcherRegistry() + start = time.perf_counter() + idle = await registry.wait_all_idle(timeout_ms) + elapsed_ms = (time.perf_counter() - start) * 1000.0 + if idle: + logger.debug( + "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f", + LOG_PREFIX, + marker, + timeout_ms, + elapsed_ms, + ) + return True + + states = await registry.get_all_operation_states() + logger.error( + "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s", + LOG_PREFIX, + marker, + timeout_ms, + elapsed_ms, + states, + ) + if fail_loud: + raise TimeoutError( + f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms" + ) + return False + except Exception: + if fail_loud: + raise + logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True) + return False + + +def get_claimed_paths() -> Set[Path]: + return _CLAIMED_PATHS + + +def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None: + """Update all active RPC instances with the current event loop. + + This MUST be called at the start of each workflow execution to ensure + RPC calls are scheduled on the correct event loop. This handles the case + where asyncio.run() creates a new event loop for each workflow. + + Args: + loop: The event loop to use. If None, uses asyncio.get_running_loop(). + """ + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.get_event_loop() + + update_count = 0 + + # Update RPCs from ExtensionManagers + for manager in _EXTENSION_MANAGERS: + if not hasattr(manager, "extensions"): + continue + for name, extension in manager.extensions.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'") + + # Also update RPCs from running extensions (they may have direct RPC refs) + for name, extension in _RUNNING_EXTENSIONS.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'") + + if update_count > 0: + logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances") + else: + logger.debug( + f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})" + ) + + +__all__ = [ + "LOG_PREFIX", + "initialize_proxies", + "initialize_isolation_nodes", + "start_isolation_loading_early", + "await_isolation_loading", + "notify_execution_graph", + "flush_running_extensions_transport_state", + "wait_for_model_patcher_quiescence", + "get_claimed_paths", + "update_rpc_event_loops", + "IsolatedNodeSpec", + "get_class_types_for_extension", +] diff --git a/comfy/isolation/adapter.py b/comfy/isolation/adapter.py new file mode 100644 index 000000000..153411c04 --- /dev/null +++ b/comfy/isolation/adapter.py @@ -0,0 +1,864 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position +from __future__ import annotations + +import logging +import os +import inspect +from pathlib import Path +from typing import Any, Dict, List, Optional, cast + +from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped] +from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped] + +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp. +# Safe to import in sealed workers without host framework modules. +from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy +from comfy.isolation.proxies.helper_proxies import HelperProxiesService +from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy + +# Singleton proxies that transitively import torch, PIL, or heavy host modules. +# Only available when torch/host framework is present. +CLIPProxy = None +CLIPRegistry = None +ModelPatcherProxy = None +ModelPatcherRegistry = None +ModelSamplingProxy = None +ModelSamplingRegistry = None +VAEProxy = None +VAERegistry = None +FirstStageModelRegistry = None +ModelManagementProxy = None +PromptServerService = None +ProgressProxy = None +UtilsProxy = None +_HAS_TORCH_PROXIES = False +if _IMPORT_TORCH: + from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry + from comfy.isolation.model_patcher_proxy import ( + ModelPatcherProxy, + ModelPatcherRegistry, + ) + from comfy.isolation.model_sampling_proxy import ( + ModelSamplingProxy, + ModelSamplingRegistry, + ) + from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + from comfy.isolation.proxies.prompt_server_impl import PromptServerService + from comfy.isolation.proxies.progress_proxy import ProgressProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + _HAS_TORCH_PROXIES = True + +logger = logging.getLogger(__name__) + +# Force /dev/shm for shared memory (bwrap makes /tmp private) +import tempfile + +if os.path.exists("/dev/shm"): + # Only override if not already set or if default is not /dev/shm + current_tmp = tempfile.gettempdir() + if not current_tmp.startswith("/dev/shm"): + logger.debug( + f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm" + ) + os.environ["TMPDIR"] = "/dev/shm" + tempfile.tempdir = None # Clear cache to force re-evaluation + + +class ComfyUIAdapter(IsolationAdapter): + # ComfyUI-specific IsolationAdapter implementation + + @property + def identifier(self) -> str: + return "comfyui" + + def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]: + if "ComfyUI" in module_path and "custom_nodes" in module_path: + parts = module_path.split("ComfyUI") + if len(parts) > 1: + comfy_root = parts[0] + "ComfyUI" + return { + "preferred_root": comfy_root, + "additional_paths": [ + os.path.join(comfy_root, "custom_nodes"), + os.path.join(comfy_root, "comfy"), + ], + "filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"], + } + return None + + def get_sandbox_system_paths(self) -> Optional[List[str]]: + """Returns required application paths to mount in the sandbox.""" + # By inspecting where our adapter is loaded from, we can determine the comfy root + adapter_file = inspect.getfile(self.__class__) + # adapter_file = /home/johnj/ComfyUI/comfy/isolation/adapter.py + comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file))) + if os.path.exists(comfy_root): + return [comfy_root] + return None + + def setup_child_environment(self, snapshot: Dict[str, Any]) -> None: + comfy_root = snapshot.get("preferred_root") + if not comfy_root: + return + + requirements_path = Path(comfy_root) / "requirements.txt" + if requirements_path.exists(): + import re + + for line in requirements_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + pkg_name = re.split(r"[<>=!~\[]", line)[0].strip() + if pkg_name: + logging.getLogger(pkg_name).setLevel(logging.ERROR) + + def register_serializers(self, registry: SerializerRegistryProtocol) -> None: + if not _IMPORT_TORCH: + # Sealed worker without torch — register torch-free TensorValue handler + # so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts. + import numpy as np + + _TORCH_DTYPE_TO_NUMPY = { + "torch.float32": np.float32, + "torch.float64": np.float64, + "torch.float16": np.float16, + "torch.bfloat16": np.float32, # numpy has no bfloat16; upcast + "torch.int32": np.int32, + "torch.int64": np.int64, + "torch.int16": np.int16, + "torch.int8": np.int8, + "torch.uint8": np.uint8, + "torch.bool": np.bool_, + } + + def _deserialize_tensor_value(data: Dict[str, Any]) -> Any: + dtype_str = data["dtype"] + np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32) + shape = tuple(data["tensor_size"]) + arr = np.array(data["data"], dtype=np_dtype).reshape(shape) + return arr + + _NUMPY_TO_TORCH_DTYPE = { + np.float32: "torch.float32", + np.float64: "torch.float64", + np.float16: "torch.float16", + np.int32: "torch.int32", + np.int64: "torch.int64", + np.int16: "torch.int16", + np.int8: "torch.int8", + np.uint8: "torch.uint8", + np.bool_: "torch.bool", + } + + def _serialize_tensor_value(obj: Any) -> Dict[str, Any]: + arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj) + dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32") + return { + "__type__": "TensorValue", + "dtype": dtype_str, + "tensor_size": list(arr.shape), + "requires_grad": False, + "data": arr.tolist(), + } + + registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True) + # ndarray output from sealed workers serializes as TensorValue for host torch reconstruction + registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True) + return + + import torch + + def serialize_device(obj: Any) -> Dict[str, Any]: + return {"__type__": "device", "device_str": str(obj)} + + def deserialize_device(data: Dict[str, Any]) -> Any: + return torch.device(data["device_str"]) + + registry.register("device", serialize_device, deserialize_device) + + _VALID_DTYPES = { + "float16", "float32", "float64", "bfloat16", + "int8", "int16", "int32", "int64", + "uint8", "bool", + } + + def serialize_dtype(obj: Any) -> Dict[str, Any]: + return {"__type__": "dtype", "dtype_str": str(obj)} + + def deserialize_dtype(data: Dict[str, Any]) -> Any: + dtype_name = data["dtype_str"].replace("torch.", "") + if dtype_name not in _VALID_DTYPES: + raise ValueError(f"Invalid dtype: {data['dtype_str']}") + return getattr(torch, dtype_name) + + registry.register("dtype", serialize_dtype, deserialize_dtype) + + from comfy_api.latest._io import FolderType + from comfy_api.latest._ui import SavedImages, SavedResult + + def serialize_saved_result(obj: Any) -> Dict[str, Any]: + return { + "__type__": "SavedResult", + "filename": obj.filename, + "subfolder": obj.subfolder, + "folder_type": obj.type.value, + } + + def deserialize_saved_result(data: Dict[str, Any]) -> Any: + if isinstance(data, SavedResult): + return data + folder_type = data["folder_type"] if "folder_type" in data else data["type"] + return SavedResult( + filename=data["filename"], + subfolder=data["subfolder"], + type=FolderType(folder_type), + ) + + registry.register( + "SavedResult", + serialize_saved_result, + deserialize_saved_result, + data_type=True, + ) + + def serialize_saved_images(obj: Any) -> Dict[str, Any]: + return { + "__type__": "SavedImages", + "results": [serialize_saved_result(result) for result in obj.results], + "is_animated": obj.is_animated, + } + + def deserialize_saved_images(data: Dict[str, Any]) -> Any: + return SavedImages( + results=[deserialize_saved_result(result) for result in data["results"]], + is_animated=data.get("is_animated", False), + ) + + registry.register( + "SavedImages", + serialize_saved_images, + deserialize_saved_images, + data_type=True, + ) + + def serialize_model_patcher(obj: Any) -> Dict[str, Any]: + # Child-side: must already have _instance_id (proxy) + if os.environ.get("PYISOLATE_CHILD") == "1": + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} + raise RuntimeError( + f"ModelPatcher in child lacks _instance_id: " + f"{type(obj).__module__}.{type(obj).__name__}" + ) + # Host-side: register with registry + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} + model_id = ModelPatcherRegistry().register(obj) + return {"__type__": "ModelPatcherRef", "model_id": model_id} + + def deserialize_model_patcher(data: Any) -> Any: + """Deserialize ModelPatcher refs; pass through already-materialized objects.""" + if isinstance(data, dict): + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + return data + + def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelPatcherRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + else: + return ModelPatcherRegistry()._get_instance(data["model_id"]) + + # Register ModelPatcher type for serialization + registry.register( + "ModelPatcher", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherProxy type (already a proxy, just return ref) + registry.register( + "ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherRef for deserialization (context-aware: host or child) + registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref) + + def serialize_clip(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "CLIPRef", "clip_id": obj._instance_id} + clip_id = CLIPRegistry().register(obj) + return {"__type__": "CLIPRef", "clip_id": clip_id} + + def deserialize_clip(data: Any) -> Any: + if isinstance(data, dict): + return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + return data + + def deserialize_clip_ref(data: Dict[str, Any]) -> Any: + """Context-aware CLIPRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + else: + return CLIPRegistry()._get_instance(data["clip_id"]) + + # Register CLIP type for serialization + registry.register("CLIP", serialize_clip, deserialize_clip) + # Register CLIPProxy type (already a proxy, just return ref) + registry.register("CLIPProxy", serialize_clip, deserialize_clip) + # Register CLIPRef for deserialization (context-aware: host or child) + registry.register("CLIPRef", None, deserialize_clip_ref) + + def serialize_vae(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "VAERef", "vae_id": obj._instance_id} + vae_id = VAERegistry().register(obj) + return {"__type__": "VAERef", "vae_id": vae_id} + + def deserialize_vae(data: Any) -> Any: + if isinstance(data, dict): + return VAEProxy(data["vae_id"]) + return data + + def deserialize_vae_ref(data: Dict[str, Any]) -> Any: + """Context-aware VAERef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + # Child: create a proxy + return VAEProxy(data["vae_id"]) + else: + # Host: lookup real VAE from registry + return VAERegistry()._get_instance(data["vae_id"]) + + # Register VAE type for serialization + registry.register("VAE", serialize_vae, deserialize_vae) + # Register VAEProxy type (already a proxy, just return ref) + registry.register("VAEProxy", serialize_vae, deserialize_vae) + # Register VAERef for deserialization (context-aware: host or child) + registry.register("VAERef", None, deserialize_vae_ref) + + # ModelSampling serialization - handles ModelSampling* types + # copyreg removed - no pickle fallback allowed + + def serialize_model_sampling(obj: Any) -> Dict[str, Any]: + # Proxy with _instance_id — return ref (works from both host and child) + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} + # Child-side: object created locally in child (e.g. ModelSamplingAdvanced + # in nodes_z_image_turbo.py). Serialize as inline data so the host can + # reconstruct the real torch.nn.Module. + if os.environ.get("PYISOLATE_CHILD") == "1": + import base64 + import io as _io + + # Identify base classes from comfy.model_sampling + bases = [] + for base in type(obj).__mro__: + if base.__module__ == "comfy.model_sampling" and base.__name__ != "object": + bases.append(base.__name__) + # Serialize state_dict as base64 safetensors-like + sd = obj.state_dict() + sd_serialized = {} + for k, v in sd.items(): + buf = _io.BytesIO() + torch.save(v, buf) + sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii") + # Capture plain attrs (shift, multiplier, sigma_data, etc.) + plain_attrs = {} + for k, v in obj.__dict__.items(): + if k.startswith("_"): + continue + if isinstance(v, (bool, int, float, str)): + plain_attrs[k] = v + return { + "__type__": "ModelSamplingInline", + "bases": bases, + "state_dict": sd_serialized, + "attrs": plain_attrs, + } + # Host-side: register with ModelSamplingRegistry and return JSON-safe dict + ms_id = ModelSamplingRegistry().register(obj) + return {"__type__": "ModelSamplingRef", "ms_id": ms_id} + + def deserialize_model_sampling(data: Any) -> Any: + """Deserialize ModelSampling refs or inline data.""" + if isinstance(data, dict): + if data.get("__type__") == "ModelSamplingInline": + return _reconstruct_model_sampling_inline(data) + return ModelSamplingProxy(data["ms_id"]) + return data + + def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any: + """Reconstruct a ModelSampling object on the host from inline child data.""" + import comfy.model_sampling as _ms + import base64 + import io as _io + + # Resolve base classes + base_classes = [] + for name in data["bases"]: + cls = getattr(_ms, name, None) + if cls is not None: + base_classes.append(cls) + if not base_classes: + raise RuntimeError( + f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}" + ) + # Create dynamic class matching the child's class hierarchy + ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {}) + obj = ReconstructedSampling.__new__(ReconstructedSampling) + torch.nn.Module.__init__(obj) + # Restore plain attributes first + for k, v in data.get("attrs", {}).items(): + setattr(obj, k, v) + # Restore state_dict (buffers like sigmas) + for k, v_b64 in data.get("state_dict", {}).items(): + buf = _io.BytesIO(base64.b64decode(v_b64)) + tensor = torch.load(buf, weights_only=True) + # Register as buffer so it's part of state_dict + parts = k.split(".") + if len(parts) == 1: + cast(Any, obj).register_buffer(parts[0], tensor) # pylint: disable=no-member + else: + setattr(obj, parts[0], tensor) + # Register on host so future references use proxy pattern. + # Skip in child process — register() is async RPC and cannot be + # called synchronously during deserialization. + if os.environ.get("PYISOLATE_CHILD") != "1": + ModelSamplingRegistry().register(obj) + return obj + + def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelSamplingRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelSamplingProxy(data["ms_id"]) + else: + return ModelSamplingRegistry()._get_instance(data["ms_id"]) + + # Register all ModelSampling* and StableCascadeSampling classes dynamically + import comfy.model_sampling + + for ms_cls in vars(comfy.model_sampling).values(): + if not isinstance(ms_cls, type): + continue + if not issubclass(ms_cls, torch.nn.Module): + continue + if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"): + continue + registry.register( + ms_cls.__name__, + serialize_model_sampling, + deserialize_model_sampling, + ) + registry.register( + "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling + ) + # Register ModelSamplingRef for deserialization (context-aware: host or child) + registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref) + # Register ModelSamplingInline for deserialization (child→host inline transfer) + registry.register( + "ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data) + ) + + def serialize_cond(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "cond": obj.cond, + } + + def deserialize_cond(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + return cls(data["cond"]) + + def _serialize_public_state(obj: Any) -> Dict[str, Any]: + state: Dict[str, Any] = {} + for key, value in obj.__dict__.items(): + if key.startswith("_"): + continue + if callable(value): + continue + state[key] = value + return state + + def serialize_latent_format(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "state": _serialize_public_state(obj), + } + + def deserialize_latent_format(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + obj = cls() + for key, value in data.get("state", {}).items(): + prop = getattr(type(obj), key, None) + if isinstance(prop, property) and prop.fset is None: + continue + setattr(obj, key, value) + return obj + + import comfy.conds + + for cond_cls in vars(comfy.conds).values(): + if not isinstance(cond_cls, type): + continue + if not issubclass(cond_cls, comfy.conds.CONDRegular): + continue + type_key = f"{cond_cls.__module__}.{cond_cls.__name__}" + registry.register(type_key, serialize_cond, deserialize_cond) + registry.register(cond_cls.__name__, serialize_cond, deserialize_cond) + + import comfy.latent_formats + + for latent_cls in vars(comfy.latent_formats).values(): + if not isinstance(latent_cls, type): + continue + if not issubclass(latent_cls, comfy.latent_formats.LatentFormat): + continue + type_key = f"{latent_cls.__module__}.{latent_cls.__name__}" + registry.register( + type_key, serialize_latent_format, deserialize_latent_format + ) + registry.register( + latent_cls.__name__, serialize_latent_format, deserialize_latent_format + ) + + # V3 API: unwrap NodeOutput.args + def deserialize_node_output(data: Any) -> Any: + return getattr(data, "args", data) + + registry.register("NodeOutput", None, deserialize_node_output) + + # KSAMPLER serializer: stores sampler name instead of function object + # sampler_function is a callable which gets filtered out by JSONSocketTransport + def serialize_ksampler(obj: Any) -> Dict[str, Any]: + func_name = obj.sampler_function.__name__ + # Map function name back to sampler name + if func_name == "sample_unipc": + sampler_name = "uni_pc" + elif func_name == "sample_unipc_bh2": + sampler_name = "uni_pc_bh2" + elif func_name == "dpm_fast_function": + sampler_name = "dpm_fast" + elif func_name == "dpm_adaptive_function": + sampler_name = "dpm_adaptive" + elif func_name.startswith("sample_"): + sampler_name = func_name[7:] # Remove "sample_" prefix + else: + sampler_name = func_name + return { + "__type__": "KSAMPLER", + "sampler_name": sampler_name, + "extra_options": obj.extra_options, + "inpaint_options": obj.inpaint_options, + } + + def deserialize_ksampler(data: Dict[str, Any]) -> Any: + import comfy.samplers + + return comfy.samplers.ksampler( + data["sampler_name"], + data.get("extra_options", {}), + data.get("inpaint_options", {}), + ) + + registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler) + + from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers + + register_hooks_serializers(registry) + + # -- File3D (comfy_api.latest._util.geometry_types) --------------------- + # Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129 + + def serialize_file3d(obj: Any) -> Dict[str, Any]: + import base64 + return { + "__type__": "File3D", + "format": obj.format, + "data": base64.b64encode(obj.get_bytes()).decode("ascii"), + } + + def deserialize_file3d(data: Any) -> Any: + import base64 + from io import BytesIO + from comfy_api.latest._util.geometry_types import File3D + return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"]) + + registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True) + + # -- VIDEO (comfy_api.latest._input_impl.video_types) ------------------- + # Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962 + + def serialize_video(obj: Any) -> Dict[str, Any]: + components = obj.get_components() + images = components.images.detach() if components.images.requires_grad else components.images + result: Dict[str, Any] = { + "__type__": "VIDEO", + "images": images, + "frame_rate_num": components.frame_rate.numerator, + "frame_rate_den": components.frame_rate.denominator, + } + if components.audio is not None: + waveform = components.audio["waveform"] + if waveform.requires_grad: + waveform = waveform.detach() + result["audio_waveform"] = waveform + result["audio_sample_rate"] = components.audio["sample_rate"] + if components.metadata is not None: + result["metadata"] = components.metadata + return result + + def deserialize_video(data: Any) -> Any: + from fractions import Fraction + from comfy_api.latest._input_impl.video_types import VideoFromComponents + from comfy_api.latest._util.video_types import VideoComponents + audio = None + if "audio_waveform" in data: + audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]} + components = VideoComponents( + images=data["images"], + frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]), + audio=audio, + metadata=data.get("metadata"), + ) + return VideoFromComponents(components) + + registry.register("VIDEO", serialize_video, deserialize_video, data_type=True) + registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True) + registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True) + + def setup_web_directory(self, module: Any) -> None: + """Detect WEB_DIRECTORY on a module and populate/register it. + + Called by the sealed worker after loading the node module. + Mirrors extension_wrapper.py:216-227 for host-coupled nodes. + Does NOT import extension_wrapper.py (it has `import torch` at module level). + """ + import shutil + + web_dir_attr = getattr(module, "WEB_DIRECTORY", None) + if web_dir_attr is None: + return + + module_dir = os.path.dirname(os.path.abspath(module.__file__)) + web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr)) + + # Read extension name from pyproject.toml + ext_name = os.path.basename(module_dir) + pyproject = os.path.join(module_dir, "pyproject.toml") + if os.path.exists(pyproject): + try: + import tomllib + except ImportError: + import tomli as tomllib # type: ignore[no-redef] + try: + with open(pyproject, "rb") as f: + data = tomllib.load(f) + name = data.get("project", {}).get("name") + if name: + ext_name = name + except Exception: + pass + + # Populate web dir if empty (mirrors _run_prestartup_web_copy) + if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))): + os.makedirs(web_dir_path, exist_ok=True) + + # Module-defined copy spec + copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None) + if copy_spec is not None and callable(copy_spec): + try: + copy_spec(web_dir_path) + except Exception as e: + logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e) + + # Fallback: comfy_3d_viewers + try: + from comfy_3d_viewers import copy_viewer, VIEWER_FILES + for viewer in VIEWER_FILES: + try: + copy_viewer(viewer, web_dir_path) + except Exception: + pass + except ImportError: + pass + + # Fallback: comfy_dynamic_widgets + try: + from comfy_dynamic_widgets import get_js_path + src = os.path.realpath(get_js_path()) + if os.path.exists(src): + dst_dir = os.path.join(web_dir_path, "js") + os.makedirs(dst_dir, exist_ok=True) + shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js")) + except ImportError: + pass + + if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)): + WebDirectoryProxy.register_web_dir(ext_name, web_dir_path) + logger.info( + "][ Adapter: registered web dir for %s (%d files)", + ext_name, + sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()), + ) + + @staticmethod + def register_host_event_handlers(extension: Any) -> None: + """Register host-side event handlers for an isolated extension. + + Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK`` + so the ComfyUI frontend receives progress bar updates. + """ + register_event_handler = inspect.getattr_static( + extension, "register_event_handler", None + ) + if not callable(register_event_handler): + return + + def _host_progress_handler(payload: dict) -> None: + import comfy.utils + + hook = comfy.utils.PROGRESS_BAR_HOOK + if hook is not None: + hook( + payload.get("value", 0), + payload.get("total", 0), + payload.get("preview"), + payload.get("node_id"), + ) + + extension.register_event_handler("progress", _host_progress_handler) + + def setup_child_event_hooks(self, extension: Any) -> None: + """Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension. + + Host-coupled only — sealed workers do not have comfy.utils (torch). + """ + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child) + if not is_child: + return + + if not _IMPORT_TORCH: + logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)") + return + + import comfy.utils + + def _event_progress_hook(value, total, preview=None, node_id=None): + logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id) + extension.emit_event("progress", { + "value": value, + "total": total, + "node_id": node_id, + }) + + comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook + logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel") + + def provide_rpc_services(self) -> List[type[ProxiedSingleton]]: + # Always available — no torch/PIL dependency + services: List[type[ProxiedSingleton]] = [ + FolderPathsProxy, + HelperProxiesService, + WebDirectoryProxy, + ] + # Torch/PIL-dependent proxies + if _HAS_TORCH_PROXIES: + services.extend([ + PromptServerService, + ModelManagementProxy, + UtilsProxy, + ProgressProxy, + VAERegistry, + CLIPRegistry, + ModelPatcherRegistry, + ModelSamplingRegistry, + FirstStageModelRegistry, + ]) + return services + + def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None: + # Resolve the real name whether it's an instance or the Singleton class itself + api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__ + + if api_name == "FolderPathsProxy": + import folder_paths + + # Replace module-level functions with proxy methods + # This is aggressive but necessary for transparent proxying + # Handle both instance and class cases + instance = api() if isinstance(api, type) else api + for name in dir(instance): + if not name.startswith("_"): + setattr(folder_paths, name, getattr(instance, name)) + + # Fence: isolated children get writable temp inside sandbox + if os.environ.get("PYISOLATE_CHILD") == "1": + import tempfile + _child_temp = os.path.join(tempfile.gettempdir(), "comfyui_temp") + os.makedirs(_child_temp, exist_ok=True) + folder_paths.temp_directory = _child_temp + + return + + if api_name == "ModelManagementProxy": + if _IMPORT_TORCH: + import comfy.model_management + + instance = api() if isinstance(api, type) else api + # Replace module-level functions with proxy methods + for name in dir(instance): + if not name.startswith("_"): + setattr(comfy.model_management, name, getattr(instance, name)) + return + + if api_name == "UtilsProxy": + if not _IMPORT_TORCH: + logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)") + return + + import comfy.utils + + # Static Injection of RPC mechanism to ensure Child can access it + # independent of instance lifecycle. + api.set_rpc(rpc) + + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child) + + # Progress hook wiring moved to setup_child_event_hooks via event channel + + return + + if api_name == "PromptServerService": + if not _IMPORT_TORCH: + return + import server + from comfy.isolation.proxies.prompt_server_impl import PromptServerStub + + stub = PromptServerStub() + if ( + hasattr(server, "PromptServer") + and getattr(server.PromptServer, "instance", None) is not stub + ): + server.PromptServer.instance = stub diff --git a/comfy/isolation/child_hooks.py b/comfy/isolation/child_hooks.py new file mode 100644 index 000000000..8aca5a18a --- /dev/null +++ b/comfy/isolation/child_hooks.py @@ -0,0 +1,122 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation +# Child process initialization for PyIsolate +import logging +import os + +logger = logging.getLogger(__name__) + + +def is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +def _load_extra_model_paths() -> None: + """Load extra_model_paths.yaml so the child's folder_paths has the same search paths as the host. + + The host loads this in main.py:143-145. The child is spawned by + pyisolate's uds_client.py and never runs main.py, so folder_paths + only has the base model directories. Any isolated node calling + folder_paths.get_filename_list() in define_schema() would get empty + results for folders whose files live in extra_model_paths locations. + """ + import folder_paths # noqa: F401 — side-effect import; load_extra_path_config writes to folder_paths internals + from utils.extra_config import load_extra_path_config + + extra_config_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "extra_model_paths.yaml", + ) + if os.path.isfile(extra_config_path): + load_extra_path_config(extra_config_path) + + +def initialize_child_process() -> None: + if os.environ.get("PYISOLATE_IMPORT_TORCH", "1") != "0": + _load_extra_model_paths() + _setup_child_loop_bridge() + + # Manual RPC injection + try: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc: + _setup_proxy_callers(rpc) + else: + _setup_proxy_callers() + except Exception as e: + logger.error(f"][ child_hooks Manual RPC Injection failed: {e}") + _setup_proxy_callers() + + _setup_logging() + + +def _setup_child_loop_bridge() -> None: + import asyncio + + main_loop = None + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + try: + main_loop = asyncio.get_event_loop() + except RuntimeError: + pass + + if main_loop is None: + return + + try: + from .proxies.base import set_global_loop + + set_global_loop(main_loop) + except ImportError: + pass + + +def _setup_prompt_server_stub(rpc=None) -> None: + try: + from .proxies.prompt_server_impl import PromptServerStub + + if rpc: + PromptServerStub.set_rpc(rpc) + elif hasattr(PromptServerStub, "clear_rpc"): + PromptServerStub.clear_rpc() + else: + PromptServerStub._rpc = None # type: ignore[attr-defined] + + except Exception as e: + logger.error(f"Failed to setup PromptServerStub: {e}") + + +def _setup_proxy_callers(rpc=None) -> None: + try: + from .proxies.folder_paths_proxy import FolderPathsProxy + from .proxies.helper_proxies import HelperProxiesService + from .proxies.model_management_proxy import ModelManagementProxy + from .proxies.progress_proxy import ProgressProxy + from .proxies.prompt_server_impl import PromptServerStub + from .proxies.utils_proxy import UtilsProxy + + if rpc is None: + FolderPathsProxy.clear_rpc() + HelperProxiesService.clear_rpc() + ModelManagementProxy.clear_rpc() + ProgressProxy.clear_rpc() + PromptServerStub.clear_rpc() + UtilsProxy.clear_rpc() + return + + FolderPathsProxy.set_rpc(rpc) + HelperProxiesService.set_rpc(rpc) + ModelManagementProxy.set_rpc(rpc) + ProgressProxy.set_rpc(rpc) + PromptServerStub.set_rpc(rpc) + UtilsProxy.set_rpc(rpc) + + except Exception as e: + logger.error(f"Failed to setup child singleton proxy callers: {e}") + + +def _setup_logging() -> None: + logging.getLogger().setLevel(logging.INFO) diff --git a/comfy/isolation/host_hooks.py b/comfy/isolation/host_hooks.py new file mode 100644 index 000000000..75966d884 --- /dev/null +++ b/comfy/isolation/host_hooks.py @@ -0,0 +1,25 @@ +# pylint: disable=import-outside-toplevel +# Host process initialization for PyIsolate +import logging + +logger = logging.getLogger(__name__) + + +def initialize_host_process() -> None: + from .proxies.folder_paths_proxy import FolderPathsProxy + from .proxies.helper_proxies import HelperProxiesService + from .proxies.model_management_proxy import ModelManagementProxy + from .proxies.progress_proxy import ProgressProxy + from .proxies.prompt_server_impl import PromptServerService + from .proxies.utils_proxy import UtilsProxy + from .proxies.web_directory_proxy import WebDirectoryProxy + from .vae_proxy import VAERegistry + + FolderPathsProxy() + HelperProxiesService() + ModelManagementProxy() + ProgressProxy() + PromptServerService() + UtilsProxy() + WebDirectoryProxy() + VAERegistry() diff --git a/comfy/isolation/manifest_loader.py b/comfy/isolation/manifest_loader.py new file mode 100644 index 000000000..4ae21d94d --- /dev/null +++ b/comfy/isolation/manifest_loader.py @@ -0,0 +1,221 @@ +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +import hashlib +import json +import logging +import os +import sys +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import folder_paths + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +LOG_PREFIX = "][" +logger = logging.getLogger(__name__) + +CACHE_SUBDIR = "cache" +CACHE_KEY_FILE = "cache_key" +CACHE_DATA_FILE = "node_info.json" +CACHE_KEY_LENGTH = 16 +_NESTED_SCAN_ROOT = "packages" +_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"} + + +def _read_manifest(manifest_path: Path) -> dict[str, Any] | None: + try: + with manifest_path.open("rb") as f: + data = tomllib.load(f) + if isinstance(data, dict): + return data + except Exception: + return None + return None + + +def _is_isolation_manifest(data: dict[str, Any]) -> bool: + return ( + "tool" in data + and "comfy" in data["tool"] + and "isolation" in data["tool"]["comfy"] + ) + + +def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]: + packages_root = entry / _NESTED_SCAN_ROOT + if not packages_root.exists() or not packages_root.is_dir(): + return [] + + nested: List[Tuple[Path, Path]] = [] + for manifest in sorted(packages_root.rglob("pyproject.toml")): + node_dir = manifest.parent + if any(part in _IGNORED_MANIFEST_DIRS for part in node_dir.parts): + continue + + data = _read_manifest(manifest) + if not data or not _is_isolation_manifest(data): + continue + + isolation = data["tool"]["comfy"]["isolation"] + if isolation.get("standalone") is True: + nested.append((node_dir, manifest)) + + return nested + + +def find_manifest_directories() -> List[Tuple[Path, Path]]: + """Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation].""" + manifest_dirs: List[Tuple[Path, Path]] = [] + + # Standard custom_nodes paths + for base_path in folder_paths.get_folder_paths("custom_nodes"): + base = Path(base_path) + if not base.exists() or not base.is_dir(): + continue + + for entry in base.iterdir(): + if not entry.is_dir(): + continue + + # Look for pyproject.toml + manifest = entry / "pyproject.toml" + if not manifest.exists(): + continue + + data = _read_manifest(manifest) + if not data or not _is_isolation_manifest(data): + continue + + manifest_dirs.append((entry, manifest)) + manifest_dirs.extend(_discover_nested_manifests(entry)) + + return manifest_dirs + + +def compute_cache_key(node_dir: Path, manifest_path: Path) -> str: + """Hash manifest + .py mtimes + Python version + PyIsolate version.""" + hasher = hashlib.sha256() + + try: + # Hashing the manifest content ensures config changes invalidate cache + hasher.update(manifest_path.read_bytes()) + except OSError: + hasher.update(b"__manifest_read_error__") + + try: + py_files = sorted(node_dir.rglob("*.py")) + for py_file in py_files: + rel_path = py_file.relative_to(node_dir) + if "__pycache__" in str(rel_path) or ".venv" in str(rel_path): + continue + hasher.update(str(rel_path).encode("utf-8")) + try: + hasher.update(str(py_file.stat().st_mtime).encode("utf-8")) + except OSError: + hasher.update(b"__file_stat_error__") + except OSError: + hasher.update(b"__dir_scan_error__") + + hasher.update(sys.version.encode("utf-8")) + + try: + import pyisolate + + hasher.update(pyisolate.__version__.encode("utf-8")) + except (ImportError, AttributeError): + hasher.update(b"__pyisolate_unknown__") + + return hasher.hexdigest()[:CACHE_KEY_LENGTH] + + +def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]: + """Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/.""" + cache_dir = venv_root / node_dir.name / CACHE_SUBDIR + return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE) + + +def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool: + """Return True only if stored cache key matches current computed key.""" + try: + cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root) + if not cache_key_file.exists() or not cache_data_file.exists(): + return False + current_key = compute_cache_key(node_dir, manifest_path) + stored_key = cache_key_file.read_text(encoding="utf-8").strip() + return current_key == stored_key + except Exception as e: + logger.debug( + "%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e + ) + return False + + +def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]: + """Load node metadata from cache, return None on any error.""" + try: + _, cache_data_file = get_cache_path(node_dir, venv_root) + if not cache_data_file.exists(): + return None + data = json.loads(cache_data_file.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return None + return data + except Exception: + return None + + +def save_to_cache( + node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path +) -> None: + """Save node metadata and cache key atomically.""" + try: + cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root) + cache_dir = cache_key_file.parent + cache_dir.mkdir(parents=True, exist_ok=True) + cache_key = compute_cache_key(node_dir, manifest_path) + + # Atomic write: data + tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp") + try: + with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f: + json.dump(node_data, f, indent=2) + os.replace(tmp_data_path, cache_data_file) + except Exception: + try: + os.unlink(tmp_data_path) + except OSError: + pass + raise + + # Atomic write: key + tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp") + try: + with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f: + f.write(cache_key) + os.replace(tmp_key_path, cache_key_file) + except Exception: + try: + os.unlink(tmp_key_path) + except OSError: + pass + raise + + except Exception as e: + logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e) + + +__all__ = [ + "LOG_PREFIX", + "find_manifest_directories", + "compute_cache_key", + "get_cache_path", + "is_cache_valid", + "load_from_cache", + "save_to_cache", +] diff --git a/comfy/isolation/rpc_bridge.py b/comfy/isolation/rpc_bridge.py new file mode 100644 index 000000000..2beb0f09f --- /dev/null +++ b/comfy/isolation/rpc_bridge.py @@ -0,0 +1,49 @@ +import asyncio +import logging +import threading + +logger = logging.getLogger(__name__) + + +class RpcBridge: + """Minimal helper to run coroutines synchronously inside isolated processes. + + If an event loop is already running, the coroutine is executed on a fresh + thread with its own loop to avoid nested run_until_complete errors. + """ + + def run_sync(self, maybe_coro): + if not asyncio.iscoroutine(maybe_coro): + return maybe_coro + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + result_container = {} + exc_container = {} + + def _runner(): + try: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + result_container["value"] = new_loop.run_until_complete(maybe_coro) + except Exception as exc: # pragma: no cover + exc_container["error"] = exc + finally: + try: + new_loop.close() + except Exception: + pass + + t = threading.Thread(target=_runner, daemon=True) + t.start() + t.join() + + if "error" in exc_container: + raise exc_container["error"] + return result_container.get("value") + + return asyncio.run(maybe_coro) diff --git a/comfy/isolation/runtime_helpers.py b/comfy/isolation/runtime_helpers.py new file mode 100644 index 000000000..f56b1859a --- /dev/null +++ b/comfy/isolation/runtime_helpers.py @@ -0,0 +1,471 @@ +# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member +from __future__ import annotations + +import copy +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Set, TYPE_CHECKING + +from .proxies.helper_proxies import restore_input_types +from .shm_forensics import scan_shm_forensics + +_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1" + +_ComfyNodeInternal = object +latest_io = None + +if _IMPORT_TORCH: + from comfy_api.internal import _ComfyNodeInternal + from comfy_api.latest import _io as latest_io + +if TYPE_CHECKING: + from .extension_wrapper import ComfyNodeExtension + +LOG_PREFIX = "][" +_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + + +class _RemoteObjectRegistryCaller: + def __init__(self, extension: Any) -> None: + self._extension = extension + + def __getattr__(self, method_name: str) -> Any: + async def _call(instance_id: str, *args: Any, **kwargs: Any) -> Any: + return await self._extension.call_remote_object_method( + instance_id, + method_name, + *args, + **kwargs, + ) + + return _call + + +def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any: + from pyisolate._internal.remote_handle import RemoteObjectHandle + + if isinstance(value, RemoteObjectHandle): + if value.type_name == "ModelPatcher": + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + proxy = ModelPatcherProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + if value.type_name == "VAE": + from comfy.isolation.vae_proxy import VAEProxy + + proxy = VAEProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + if value.type_name == "CLIP": + from comfy.isolation.clip_proxy import CLIPProxy + + proxy = CLIPProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + if value.type_name == "ModelSampling": + from comfy.isolation.model_sampling_proxy import ModelSamplingProxy + + proxy = ModelSamplingProxy(value.object_id, manage_lifecycle=False) + proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined] + proxy._pyisolate_remote_handle = value # type: ignore[attr-defined] + return proxy + return value + + if isinstance(value, dict): + return { + k: _wrap_remote_handles_as_host_proxies(v, extension) for k, v in value.items() + } + + if isinstance(value, (list, tuple)): + wrapped = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value] + return type(value)(wrapped) + + return value + + +def _resource_snapshot() -> Dict[str, int]: + fd_count = -1 + shm_sender_files = 0 + try: + fd_count = len(os.listdir("/proc/self/fd")) + except Exception: + pass + try: + shm_root = Path("/dev/shm") + if shm_root.exists(): + prefix = f"torch_{os.getpid()}_" + shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*")) + except Exception: + pass + return {"fd_count": fd_count, "shm_sender_files": shm_sender_files} + + +def _tensor_transport_summary(value: Any) -> Dict[str, int]: + summary: Dict[str, int] = { + "tensor_count": 0, + "cpu_tensors": 0, + "cuda_tensors": 0, + "shared_cpu_tensors": 0, + "tensor_bytes": 0, + } + try: + import torch + except Exception: + return summary + + def visit(node: Any) -> None: + if isinstance(node, torch.Tensor): + summary["tensor_count"] += 1 + summary["tensor_bytes"] += int(node.numel() * node.element_size()) + if node.device.type == "cpu": + summary["cpu_tensors"] += 1 + if node.is_shared(): + summary["shared_cpu_tensors"] += 1 + elif node.device.type == "cuda": + summary["cuda_tensors"] += 1 + return + if isinstance(node, dict): + for v in node.values(): + visit(v) + return + if isinstance(node, (list, tuple)): + for v in node: + visit(v) + + visit(value) + return summary + + +def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None: + for key, value in inputs.items(): + key_text = str(key) + if "unique_id" in key_text: + return str(value) + return None + + +def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + return + if not callable(flush_tensor_keeper): + return + flushed = flush_tensor_keeper() + if flushed > 0: + logger.debug( + "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed + ) + + +def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None: + import comfy.model_management as model_management + + model_management.cleanup_models_gc() + model_management.cleanup_models() + + device = model_management.get_torch_device() + if not hasattr(device, "type") or device.type == "cpu": + return + + required = max( + model_management.minimum_inference_memory(), + _PRE_EXEC_MIN_FREE_VRAM_BYTES, + ) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=True) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.cleanup_models() + model_management.soft_empty_cache() + logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required) + + +def _detach_shared_cpu_tensors(value: Any) -> Any: + try: + import torch + except Exception: + return value + + if isinstance(value, torch.Tensor): + if value.device.type == "cpu" and value.is_shared(): + clone = value.clone() + if value.requires_grad: + clone.requires_grad_(True) + return clone + return value + if isinstance(value, list): + return [_detach_shared_cpu_tensors(v) for v in value] + if isinstance(value, tuple): + return tuple(_detach_shared_cpu_tensors(v) for v in value) + if isinstance(value, dict): + return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()} + return value + + +def build_stub_class( + node_name: str, + info: Dict[str, object], + extension: "ComfyNodeExtension", + running_extensions: Dict[str, "ComfyNodeExtension"], + logger: logging.Logger, +) -> type: + if latest_io is None: + raise RuntimeError("comfy_api.latest._io is required to build isolation stubs") + is_v3 = bool(info.get("is_v3", False)) + function_name = "_pyisolate_execute" + restored_input_types = restore_input_types(info.get("input_types", {})) + + async def _execute(self, **inputs): + from comfy.isolation import _RUNNING_EXTENSIONS + + # Update BOTH the local dict AND the module-level dict + running_extensions[extension.name] = extension + _RUNNING_EXTENSIONS[extension.name] = extension + prev_child = None + node_unique_id = _extract_hidden_unique_id(inputs) + summary = _tensor_transport_summary(inputs) + resources = _resource_snapshot() + logger.debug( + "%s ISO:execute_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + logger.debug( + "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + summary["tensor_count"], + summary["cpu_tensors"], + summary["cuda_tensors"], + summary["shared_cpu_tensors"], + summary["tensor_bytes"], + resources["fd_count"], + resources["shm_sender_files"], + ) + scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True) + try: + if os.environ.get("PYISOLATE_CHILD") != "1": + _relieve_host_vram_pressure("RUNTIME:pre_execute", logger) + scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True) + from pyisolate._internal.model_serialization import ( + serialize_for_isolation, + deserialize_from_isolation, + ) + + prev_child = os.environ.pop("PYISOLATE_CHILD", None) + logger.debug( + "%s ISO:serialize_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + # Unwrap NodeOutput-like dicts before serialization. + # OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)} + # and the executor may pass this dict as input to downstream nodes. + unwrapped_inputs = {} + for k, v in inputs.items(): + if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v): + result = v.get("result") + if isinstance(result, (tuple, list)) and len(result) > 0: + unwrapped_inputs[k] = result[0] + else: + unwrapped_inputs[k] = result + else: + unwrapped_inputs[k] = v + serialized = serialize_for_isolation(unwrapped_inputs) + logger.debug( + "%s ISO:serialize_done ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + logger.debug( + "%s ISO:dispatch_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + result = await extension.execute_node(node_name, **serialized) + logger.debug( + "%s ISO:dispatch_done ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + # Reconstruct NodeOutput if the child serialized one + if isinstance(result, dict) and result.get("__node_output__"): + from comfy_api.latest import io as latest_io + args_raw = result.get("args", ()) + deserialized_args = await deserialize_from_isolation(args_raw, extension) + deserialized_args = _wrap_remote_handles_as_host_proxies( + deserialized_args, extension + ) + deserialized_args = _detach_shared_cpu_tensors(deserialized_args) + ui_raw = result.get("ui") + deserialized_ui = None + if ui_raw is not None: + deserialized_ui = await deserialize_from_isolation(ui_raw, extension) + deserialized_ui = _wrap_remote_handles_as_host_proxies( + deserialized_ui, extension + ) + deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui) + scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True) + return latest_io.NodeOutput( + *deserialized_args, + ui=deserialized_ui, + expand=result.get("expand"), + block_execution=result.get("block_execution"), + ) + # OUTPUT_NODE: if sealed worker returned a tuple/list whose first + # element is a {"ui": ...} dict, unwrap it for the executor. + if (isinstance(result, (tuple, list)) and len(result) == 1 + and isinstance(result[0], dict) and "ui" in result[0]): + return result[0] + deserialized = await deserialize_from_isolation(result, extension) + deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension) + scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True) + return _detach_shared_cpu_tensors(deserialized) + except ImportError: + return await extension.execute_node(node_name, **inputs) + except Exception: + logger.exception( + "%s ISO:execute_error ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + raise + finally: + if prev_child is not None: + os.environ["PYISOLATE_CHILD"] = prev_child + logger.debug( + "%s ISO:execute_end ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True) + + def _input_types( + cls, + include_hidden: bool = True, + return_schema: bool = False, + live_inputs: Any = None, + ): + if not is_v3: + return restored_input_types + + inputs_copy = copy.deepcopy(restored_input_types) + if not include_hidden: + inputs_copy.pop("hidden", None) + + v3_data: Dict[str, Any] = {"hidden_inputs": {}} + dynamic = inputs_copy.pop("dynamic_paths", None) + if dynamic is not None: + v3_data["dynamic_paths"] = dynamic + + if return_schema: + hidden_vals = info.get("hidden", []) or [] + hidden_enums = [] + for h in hidden_vals: + try: + hidden_enums.append(latest_io.Hidden(h)) + except Exception: + hidden_enums.append(h) + + class SchemaProxy: + hidden = hidden_enums + + return inputs_copy, SchemaProxy, v3_data + return inputs_copy + + def _validate_class(cls): + return True + + def _get_node_info_v1(cls): + node_info = copy.deepcopy(info.get("schema_v1", {})) + relative_python_module = node_info.get("python_module") + if not isinstance(relative_python_module, str) or not relative_python_module: + relative_python_module = f"custom_nodes.{extension.name}" + node_info["python_module"] = relative_python_module + return node_info + + def _get_base_class(cls): + return latest_io.ComfyNode + + attributes: Dict[str, object] = { + "FUNCTION": function_name, + "CATEGORY": info.get("category", ""), + "OUTPUT_NODE": info.get("output_node", False), + "RETURN_TYPES": tuple(info.get("return_types", ()) or ()), + "RETURN_NAMES": info.get("return_names"), + function_name: _execute, + "_pyisolate_extension": extension, + "_pyisolate_node_name": node_name, + "INPUT_TYPES": classmethod(_input_types), + } + + output_is_list = info.get("output_is_list") + if output_is_list is not None: + attributes["OUTPUT_IS_LIST"] = tuple(output_is_list) + + if is_v3: + attributes["VALIDATE_CLASS"] = classmethod(_validate_class) + attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1) + attributes["GET_BASE_CLASS"] = classmethod(_get_base_class) + attributes["DESCRIPTION"] = info.get("description", "") + attributes["EXPERIMENTAL"] = info.get("experimental", False) + attributes["DEPRECATED"] = info.get("deprecated", False) + attributes["API_NODE"] = info.get("api_node", False) + attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False) + attributes["ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False) + attributes["_ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False) + attributes["INPUT_IS_LIST"] = info.get("input_is_list", False) + + class_name = f"PyIsolate_{node_name}".replace(" ", "_") + bases = (_ComfyNodeInternal,) if is_v3 else () + stub_cls = type(class_name, bases, attributes) + + if is_v3: + try: + stub_cls.VALIDATE_CLASS() + except Exception as e: + logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e) + + return stub_cls + + +def get_class_types_for_extension( + extension_name: str, + running_extensions: Dict[str, "ComfyNodeExtension"], + specs: List[Any], +) -> Set[str]: + extension = running_extensions.get(extension_name) + if not extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in specs: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + return class_types + + +__all__ = ["build_stub_class", "get_class_types_for_extension"] diff --git a/comfy/isolation/shm_forensics.py b/comfy/isolation/shm_forensics.py new file mode 100644 index 000000000..36223505a --- /dev/null +++ b/comfy/isolation/shm_forensics.py @@ -0,0 +1,217 @@ +# pylint: disable=consider-using-from-import,import-outside-toplevel +from __future__ import annotations + +import atexit +import hashlib +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Set + +LOG_PREFIX = "][" +logger = logging.getLogger(__name__) + + +def _shm_debug_enabled() -> bool: + return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1" + + +class _SHMForensicsTracker: + def __init__(self) -> None: + self._started = False + self._tracked_files: Set[str] = set() + self._current_model_context: Dict[str, str] = { + "id": "unknown", + "name": "unknown", + "hash": "????", + } + + @staticmethod + def _snapshot_shm() -> Set[str]: + shm_path = Path("/dev/shm") + if not shm_path.exists(): + return set() + return {f.name for f in shm_path.glob("torch_*")} + + def start(self) -> None: + if self._started or not _shm_debug_enabled(): + return + self._tracked_files = self._snapshot_shm() + self._started = True + logger.debug( + "%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files) + ) + + def stop(self) -> None: + if not self._started: + return + self.scan("shutdown", refresh_model_context=True) + self._started = False + logger.debug("%s SHM:forensics_disabled", LOG_PREFIX) + + def _compute_model_hash(self, model_patcher: Any) -> str: + try: + model_instance_id = getattr(model_patcher, "_instance_id", None) + if model_instance_id is not None: + model_id_text = str(model_instance_id) + return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text + + import torch + + real_model = ( + model_patcher.model + if hasattr(model_patcher, "model") + else model_patcher + ) + tensor = None + if hasattr(real_model, "parameters"): + for p in real_model.parameters(): + if torch.is_tensor(p) and p.numel() > 0: + tensor = p + break + + if tensor is None: + return "0000" + + flat = tensor.flatten() + values = [] + indices = [0, flat.shape[0] // 2, flat.shape[0] - 1] + for i in indices: + if i < flat.shape[0]: + values.append(flat[i].item()) + + size = 0 + if hasattr(model_patcher, "model_size"): + size = model_patcher.model_size() + sample_str = f"{values}_{id(model_patcher):016x}_{size}" + return hashlib.sha256(sample_str.encode()).hexdigest()[-4:] + except Exception: + return "err!" + + def _get_models_snapshot(self) -> List[Dict[str, Any]]: + try: + import comfy.model_management as model_management + except Exception: + return [] + + snapshot: List[Dict[str, Any]] = [] + try: + for loaded_model in model_management.current_loaded_models: + model = loaded_model.model + if model is None: + continue + if str(getattr(loaded_model, "device", "")) != "cuda:0": + continue + + name = ( + model.model.__class__.__name__ + if hasattr(model, "model") + else type(model).__name__ + ) + model_hash = self._compute_model_hash(model) + model_instance_id = getattr(model, "_instance_id", None) + if model_instance_id is None: + model_instance_id = model_hash + snapshot.append( + { + "name": str(name), + "id": str(model_instance_id), + "hash": str(model_hash or "????"), + "used": bool(getattr(loaded_model, "currently_used", False)), + } + ) + except Exception: + return [] + + return snapshot + + def _update_model_context(self) -> None: + snapshot = self._get_models_snapshot() + selected = None + + used_models = [m for m in snapshot if m.get("used") and m.get("id")] + if used_models: + selected = used_models[-1] + else: + live_models = [m for m in snapshot if m.get("id")] + if live_models: + selected = live_models[-1] + + if selected is None: + self._current_model_context = { + "id": "unknown", + "name": "unknown", + "hash": "????", + } + return + + self._current_model_context = { + "id": str(selected.get("id", "unknown")), + "name": str(selected.get("name", "unknown")), + "hash": str(selected.get("hash", "????") or "????"), + } + + def scan(self, marker: str, refresh_model_context: bool = True) -> None: + if not self._started or not _shm_debug_enabled(): + return + + if refresh_model_context: + self._update_model_context() + + current = self._snapshot_shm() + added = current - self._tracked_files + removed = self._tracked_files - current + self._tracked_files = current + + if not added and not removed: + logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker) + return + + for filename in sorted(added): + logger.info("%s SHM:created | %s", LOG_PREFIX, filename) + model_id = self._current_model_context["id"] + if model_id == "unknown": + logger.error( + "%s SHM:model_association_missing | file=%s | reason=no_active_model_context", + LOG_PREFIX, + filename, + ) + else: + logger.info( + "%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s", + LOG_PREFIX, + model_id, + filename, + self._current_model_context["name"], + self._current_model_context["hash"], + ) + + for filename in sorted(removed): + logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename) + + logger.debug( + "%s SHM:scan marker=%s created=%d deleted=%d active=%d", + LOG_PREFIX, + marker, + len(added), + len(removed), + len(self._tracked_files), + ) + + +_TRACKER = _SHMForensicsTracker() + + +def start_shm_forensics() -> None: + _TRACKER.start() + + +def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None: + _TRACKER.scan(marker, refresh_model_context=refresh_model_context) + + +def stop_shm_forensics() -> None: + _TRACKER.stop() + + +atexit.register(stop_shm_forensics) diff --git a/pyproject.toml b/pyproject.toml index 633dac517..51584bf01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,17 @@ homepage = "https://www.comfy.org/" repository = "https://github.com/comfyanonymous/ComfyUI" documentation = "https://docs.comfy.org/" +[tool.comfy.host] +sandbox_mode = "required" +allow_network = false +writable_paths = ["/dev/shm"] + +[tool.comfy.host.whitelist] +"ComfyUI-GGUF" = "*" +"ComfyUI-KJNodes" = "*" +"ComfyUI-Manager" = "*" +"websocket_image_save.py" = "*" + [tool.ruff] lint.select = [ "N805", # invalid-first-argument-name-for-method diff --git a/requirements.txt b/requirements.txt index c3d51e2fa..a34b5a99d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,5 @@ pydantic~=2.0 pydantic-settings~=2.0 PyOpenGL glfw +uv +pyisolate==0.10.2