mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-02 13:41:03 +01:00
328 lines
12 KiB
Python
328 lines
12 KiB
Python
# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation
|
|
from __future__ import annotations
|
|
import asyncio
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import time
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
|
import folder_paths
|
|
from .extension_loader import load_isolated_node
|
|
from .manifest_loader import find_manifest_directories
|
|
from .runtime_helpers import build_stub_class, get_class_types_for_extension
|
|
from .shm_forensics import scan_shm_forensics, start_shm_forensics
|
|
|
|
if TYPE_CHECKING:
|
|
from pyisolate import ExtensionManager
|
|
from .extension_wrapper import ComfyNodeExtension
|
|
|
|
LOG_PREFIX = "]["
|
|
isolated_node_timings: List[tuple[float, Path, int]] = []
|
|
|
|
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
|
|
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
|
|
|
|
|
def initialize_proxies() -> None:
|
|
from .child_hooks import is_child_process
|
|
|
|
is_child = is_child_process()
|
|
|
|
if is_child:
|
|
from .child_hooks import initialize_child_process
|
|
|
|
initialize_child_process()
|
|
else:
|
|
from .host_hooks import initialize_host_process
|
|
|
|
initialize_host_process()
|
|
start_shm_forensics()
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class IsolatedNodeSpec:
|
|
node_name: str
|
|
display_name: str
|
|
stub_class: type
|
|
module_path: Path
|
|
|
|
|
|
_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = []
|
|
_CLAIMED_PATHS: Set[Path] = set()
|
|
_ISOLATION_SCAN_ATTEMPTED = False
|
|
_EXTENSION_MANAGERS: List["ExtensionManager"] = []
|
|
_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {}
|
|
_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None
|
|
_EARLY_START_TIME: Optional[float] = None
|
|
|
|
|
|
def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None:
|
|
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
|
if _ISOLATION_BACKGROUND_TASK is not None:
|
|
return
|
|
_EARLY_START_TIME = time.perf_counter()
|
|
_ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes())
|
|
|
|
|
|
async def await_isolation_loading() -> List[IsolatedNodeSpec]:
|
|
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
|
if _ISOLATION_BACKGROUND_TASK is not None:
|
|
specs = await _ISOLATION_BACKGROUND_TASK
|
|
return specs
|
|
return await initialize_isolation_nodes()
|
|
|
|
|
|
async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
|
|
global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS
|
|
|
|
if _ISOLATED_NODE_SPECS:
|
|
return _ISOLATED_NODE_SPECS
|
|
|
|
if _ISOLATION_SCAN_ATTEMPTED:
|
|
return []
|
|
|
|
_ISOLATION_SCAN_ATTEMPTED = True
|
|
manifest_entries = find_manifest_directories()
|
|
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
|
|
|
|
if not manifest_entries:
|
|
return []
|
|
|
|
os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1"
|
|
concurrency_limit = max(1, (os.cpu_count() or 4) // 2)
|
|
semaphore = asyncio.Semaphore(concurrency_limit)
|
|
|
|
async def load_with_semaphore(
|
|
node_dir: Path, manifest: Path
|
|
) -> List[IsolatedNodeSpec]:
|
|
async with semaphore:
|
|
load_start = time.perf_counter()
|
|
spec_list = await load_isolated_node(
|
|
node_dir,
|
|
manifest,
|
|
logger,
|
|
lambda name, info, extension: build_stub_class(
|
|
name,
|
|
info,
|
|
extension,
|
|
_RUNNING_EXTENSIONS,
|
|
logger,
|
|
),
|
|
PYISOLATE_VENV_ROOT,
|
|
_EXTENSION_MANAGERS,
|
|
)
|
|
spec_list = [
|
|
IsolatedNodeSpec(
|
|
node_name=node_name,
|
|
display_name=display_name,
|
|
stub_class=stub_cls,
|
|
module_path=node_dir,
|
|
)
|
|
for node_name, display_name, stub_cls in spec_list
|
|
]
|
|
isolated_node_timings.append(
|
|
(time.perf_counter() - load_start, node_dir, len(spec_list))
|
|
)
|
|
return spec_list
|
|
|
|
tasks = [
|
|
load_with_semaphore(node_dir, manifest)
|
|
for node_dir, manifest in manifest_entries
|
|
]
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
specs: List[IsolatedNodeSpec] = []
|
|
for result in results:
|
|
if isinstance(result, Exception):
|
|
logger.error(
|
|
"%s Isolated node failed during startup; continuing: %s",
|
|
LOG_PREFIX,
|
|
result,
|
|
)
|
|
continue
|
|
specs.extend(result)
|
|
|
|
_ISOLATED_NODE_SPECS = specs
|
|
return list(_ISOLATED_NODE_SPECS)
|
|
|
|
|
|
def _get_class_types_for_extension(extension_name: str) -> Set[str]:
|
|
"""Get all node class types (node names) belonging to an extension."""
|
|
extension = _RUNNING_EXTENSIONS.get(extension_name)
|
|
if not extension:
|
|
return set()
|
|
|
|
ext_path = Path(extension.module_path)
|
|
class_types = set()
|
|
for spec in _ISOLATED_NODE_SPECS:
|
|
if spec.module_path.resolve() == ext_path.resolve():
|
|
class_types.add(spec.node_name)
|
|
|
|
return class_types
|
|
|
|
|
|
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
|
"""Evict running extensions not needed for current execution."""
|
|
|
|
async def _stop_extension(
|
|
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
|
) -> None:
|
|
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
|
|
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
|
|
stop_result = extension.stop()
|
|
if inspect.isawaitable(stop_result):
|
|
await stop_result
|
|
_RUNNING_EXTENSIONS.pop(ext_name, None)
|
|
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
|
|
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
|
|
|
|
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
|
|
logger.debug(
|
|
"%s ISO:notify_graph_start running=%d needed=%d",
|
|
LOG_PREFIX,
|
|
len(_RUNNING_EXTENSIONS),
|
|
len(needed_class_types),
|
|
)
|
|
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
|
ext_class_types = _get_class_types_for_extension(ext_name)
|
|
|
|
# If NONE of this extension's nodes are in the execution graph → evict
|
|
if not ext_class_types.intersection(needed_class_types):
|
|
await _stop_extension(
|
|
ext_name,
|
|
extension,
|
|
"isolated custom_node not in execution graph, evicting",
|
|
)
|
|
|
|
# Isolated child processes add steady VRAM pressure; reclaim host-side models
|
|
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
|
|
try:
|
|
import comfy.model_management as model_management
|
|
|
|
device = model_management.get_torch_device()
|
|
if getattr(device, "type", None) == "cuda":
|
|
required = max(
|
|
model_management.minimum_inference_memory(),
|
|
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
|
|
)
|
|
free_before = model_management.get_free_memory(device)
|
|
if free_before < required and _RUNNING_EXTENSIONS:
|
|
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
|
await _stop_extension(
|
|
ext_name,
|
|
extension,
|
|
f"boundary low-vram restart (free={int(free_before)} target={int(required)})",
|
|
)
|
|
if model_management.get_free_memory(device) < required:
|
|
model_management.unload_all_models()
|
|
model_management.cleanup_models_gc()
|
|
model_management.cleanup_models()
|
|
if model_management.get_free_memory(device) < required:
|
|
model_management.free_memory(required, device, for_dynamic=False)
|
|
model_management.soft_empty_cache()
|
|
except Exception:
|
|
logger.debug(
|
|
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
|
|
)
|
|
finally:
|
|
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
|
|
logger.debug(
|
|
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
|
|
)
|
|
|
|
|
|
async def flush_running_extensions_transport_state() -> int:
|
|
total_flushed = 0
|
|
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
|
flush_fn = getattr(extension, "flush_transport_state", None)
|
|
if not callable(flush_fn):
|
|
continue
|
|
try:
|
|
flushed = await flush_fn()
|
|
if isinstance(flushed, int):
|
|
total_flushed += flushed
|
|
if flushed > 0:
|
|
logger.debug(
|
|
"%s %s workflow-end flush released=%d",
|
|
LOG_PREFIX,
|
|
ext_name,
|
|
flushed,
|
|
)
|
|
except Exception:
|
|
logger.debug(
|
|
"%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True
|
|
)
|
|
scan_shm_forensics(
|
|
"ISO:flush_running_extensions_transport_state", refresh_model_context=True
|
|
)
|
|
return total_flushed
|
|
|
|
|
|
def get_claimed_paths() -> Set[Path]:
|
|
return _CLAIMED_PATHS
|
|
|
|
|
|
def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None:
|
|
"""Update all active RPC instances with the current event loop.
|
|
|
|
This MUST be called at the start of each workflow execution to ensure
|
|
RPC calls are scheduled on the correct event loop. This handles the case
|
|
where asyncio.run() creates a new event loop for each workflow.
|
|
|
|
Args:
|
|
loop: The event loop to use. If None, uses asyncio.get_running_loop().
|
|
"""
|
|
if loop is None:
|
|
try:
|
|
loop = asyncio.get_running_loop()
|
|
except RuntimeError:
|
|
loop = asyncio.get_event_loop()
|
|
|
|
update_count = 0
|
|
|
|
# Update RPCs from ExtensionManagers
|
|
for manager in _EXTENSION_MANAGERS:
|
|
if not hasattr(manager, "extensions"):
|
|
continue
|
|
for name, extension in manager.extensions.items():
|
|
if hasattr(extension, "rpc") and extension.rpc is not None:
|
|
if hasattr(extension.rpc, "update_event_loop"):
|
|
extension.rpc.update_event_loop(loop)
|
|
update_count += 1
|
|
logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'")
|
|
|
|
# Also update RPCs from running extensions (they may have direct RPC refs)
|
|
for name, extension in _RUNNING_EXTENSIONS.items():
|
|
if hasattr(extension, "rpc") and extension.rpc is not None:
|
|
if hasattr(extension.rpc, "update_event_loop"):
|
|
extension.rpc.update_event_loop(loop)
|
|
update_count += 1
|
|
logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'")
|
|
|
|
if update_count > 0:
|
|
logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances")
|
|
else:
|
|
logger.debug(
|
|
f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})"
|
|
)
|
|
|
|
|
|
__all__ = [
|
|
"LOG_PREFIX",
|
|
"initialize_proxies",
|
|
"initialize_isolation_nodes",
|
|
"start_isolation_loading_early",
|
|
"await_isolation_loading",
|
|
"notify_execution_graph",
|
|
"flush_running_extensions_transport_state",
|
|
"get_claimed_paths",
|
|
"update_rpc_event_loops",
|
|
"IsolatedNodeSpec",
|
|
"get_class_types_for_extension",
|
|
]
|