mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-02 13:41:03 +01:00
344 lines
11 KiB
Python
344 lines
11 KiB
Python
# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Set, TYPE_CHECKING
|
|
|
|
from .proxies.helper_proxies import restore_input_types
|
|
from comfy_api.internal import _ComfyNodeInternal
|
|
from comfy_api.latest import _io as latest_io
|
|
from .shm_forensics import scan_shm_forensics
|
|
|
|
if TYPE_CHECKING:
|
|
from .extension_wrapper import ComfyNodeExtension
|
|
|
|
# Tag passed as the first argument of every log call made by this module.
LOG_PREFIX = "]["

# 2 GiB. Used as a lower bound for the free-memory target requested before an
# isolated node executes (see _relieve_host_vram_pressure).
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024**3
|
|
|
|
|
|
def _resource_snapshot() -> Dict[str, int]:
    """Return a best-effort count of open fds and torch shm files for this process.

    Returns:
        A dict with two keys:
        ``fd_count`` -- number of entries in ``/proc/self/fd``, or ``-1`` when
        that directory cannot be listed.
        ``shm_sender_files`` -- number of entries in ``/dev/shm`` whose name
        starts with ``torch_<pid>_``, or ``0`` when the directory is missing
        or cannot be scanned.
    """
    snapshot: Dict[str, int] = {"fd_count": -1, "shm_sender_files": 0}

    # Both probes are purely diagnostic: any failure keeps the default value.
    try:
        snapshot["fd_count"] = len(os.listdir("/proc/self/fd"))
    except Exception:
        pass

    try:
        shm_dir = Path("/dev/shm")
        if shm_dir.exists():
            pattern = f"torch_{os.getpid()}_*"
            snapshot["shm_sender_files"] = len(list(shm_dir.glob(pattern)))
    except Exception:
        pass

    return snapshot
|
|
|
|
|
|
def _tensor_transport_summary(value: Any) -> Dict[str, int]:
    """Count the torch tensors nested inside ``value`` and their total byte size.

    Only ``dict`` values and ``list``/``tuple`` items are traversed; any other
    container type is treated as a leaf and ignored.

    Args:
        value: Arbitrary (possibly nested) structure to inspect.

    Returns:
        A dict with the counters ``tensor_count``, ``cpu_tensors``,
        ``cuda_tensors``, ``shared_cpu_tensors`` and ``tensor_bytes``. All
        counters are zero when ``torch`` cannot be imported.
    """
    counts: Dict[str, int] = dict.fromkeys(
        (
            "tensor_count",
            "cpu_tensors",
            "cuda_tensors",
            "shared_cpu_tensors",
            "tensor_bytes",
        ),
        0,
    )
    try:
        import torch
    except Exception:
        return counts

    def tensors_in(node: Any):
        # Depth-first walk yielding every tensor reachable through
        # dict values and list/tuple items.
        if isinstance(node, torch.Tensor):
            yield node
        elif isinstance(node, dict):
            for child in node.values():
                yield from tensors_in(child)
        elif isinstance(node, (list, tuple)):
            for child in node:
                yield from tensors_in(child)

    for tensor in tensors_in(value):
        counts["tensor_count"] += 1
        counts["tensor_bytes"] += int(tensor.numel() * tensor.element_size())
        device_kind = tensor.device.type
        if device_kind == "cpu":
            counts["cpu_tensors"] += 1
            if tensor.is_shared():
                counts["shared_cpu_tensors"] += 1
        elif device_kind == "cuda":
            counts["cuda_tensors"] += 1

    return counts
|
|
|
|
|
|
def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None:
    """Return the node's unique id found in ``inputs``, as a string.

    The first entry (in dict order) whose key, converted to ``str``, contains
    ``"unique_id"`` and whose value is not ``None`` is used.

    Args:
        inputs: Keyword inputs passed to the node's execute function.

    Returns:
        ``str(value)`` of the matching entry, or ``None`` when no entry
        carries a usable id.
    """
    for key, value in inputs.items():
        # A None value means "no id". Returning str(None) would hand callers
        # the literal text "None" and defeat their `... or "-"` fallback.
        if "unique_id" in str(key) and value is not None:
            return str(value)
    return None
|
|
|
|
|
|
def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None:
    """Call ``pyisolate.flush_tensor_keeper`` when it is available.

    Does nothing when the name cannot be imported from ``pyisolate`` or when
    it is not callable.

    Args:
        marker: Label included in the debug message.
        logger: Logger that receives the debug message; a message is written
            only when the flush call returns a value greater than zero.
    """
    try:
        from pyisolate import flush_tensor_keeper as keeper_flush  # type: ignore[attr-defined]
    except Exception:
        keeper_flush = None

    if callable(keeper_flush):
        released = keeper_flush()
        if released > 0:
            logger.debug(
                "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, released
            )
|
|
|
|
|
|
def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None:
    """Ask ``comfy.model_management`` to free device memory before a node runs.

    Args:
        marker: Label included in the final debug message.
        logger: Logger that receives the final debug message.

    Returns early, after the initial cleanup calls, when the torch device has
    no ``type`` attribute or is the CPU.
    """
    # Imported lazily, inside the function, rather than at module import time.
    import comfy.model_management as model_management

    # Initial cleanup runs unconditionally, even for CPU-only setups.
    model_management.cleanup_models_gc()
    model_management.cleanup_models()

    device = model_management.get_torch_device()
    if not hasattr(device, "type") or device.type == "cpu":
        return

    # Target is the larger of the model manager's own minimum and the
    # module-level floor _PRE_EXEC_MIN_FREE_VRAM_BYTES.
    required = max(
        model_management.minimum_inference_memory(),
        _PRE_EXEC_MIN_FREE_VRAM_BYTES,
    )
    # Two-stage request: first with for_dynamic=True, then - only if free
    # memory is still below the target - again with for_dynamic=False.
    # NOTE(review): the exact meaning of for_dynamic is defined by
    # model_management.free_memory and is not visible here - confirm.
    if model_management.get_free_memory(device) < required:
        model_management.free_memory(required, device, for_dynamic=True)
    if model_management.get_free_memory(device) < required:
        model_management.free_memory(required, device, for_dynamic=False)
    model_management.cleanup_models()
    model_management.soft_empty_cache()
    logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
|
|
|
|
|
def _detach_shared_cpu_tensors(value: Any) -> Any:
    """Replace every shared-memory CPU tensor inside ``value`` with a clone.

    ``list``, ``tuple`` and ``dict`` containers are rebuilt recursively as
    plain ``list``/``tuple``/``dict`` objects; every other object, including
    tensors that are not shared CPU tensors, is returned as-is.

    Args:
        value: Arbitrary (possibly nested) result structure.

    Returns:
        The structure with shared CPU tensors cloned, or ``value`` itself
        when ``torch`` cannot be imported.
    """
    try:
        import torch
    except Exception:
        return value

    if isinstance(value, torch.Tensor):
        backed_by_shm = value.device.type == "cpu" and value.is_shared()
        if not backed_by_shm:
            return value
        private_copy = value.clone()
        if value.requires_grad:
            private_copy.requires_grad_(True)
        return private_copy

    if isinstance(value, dict):
        return {key: _detach_shared_cpu_tensors(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_detach_shared_cpu_tensors(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_detach_shared_cpu_tensors(item) for item in value)
    return value
|
|
|
|
|
|
def build_stub_class(
    node_name: str,
    info: Dict[str, object],
    extension: "ComfyNodeExtension",
    running_extensions: Dict[str, "ComfyNodeExtension"],
    logger: logging.Logger,
) -> type:
    """Create a stub node class whose execute function forwards to ``extension``.

    The class is built dynamically with ``type()`` and is named
    ``PyIsolate_<node_name>`` with spaces replaced by underscores. Its
    ``FUNCTION`` attribute names an async method that serializes the inputs,
    awaits ``extension.execute_node`` and deserializes the result.

    Args:
        node_name: Name of the node; passed to ``extension.execute_node`` and
            used to derive the class name.
        info: Node metadata. Keys read here: ``is_v3``, ``input_types``,
            ``hidden``, ``schema_v1``, ``category``, ``output_node``,
            ``return_types``, ``return_names``, ``output_is_list`` and, for
            v3 nodes, ``description``, ``experimental``, ``deprecated``,
            ``api_node``, ``not_idempotent``, ``input_is_list``.
        extension: Extension object that actually executes the node.
        running_extensions: Registry dict; ``extension`` is stored in it under
            ``extension.name`` each time the stub executes.
        logger: Logger used for debug/error messages.

    Returns:
        The newly created stub class. For v3 nodes it derives from
        ``_ComfyNodeInternal``; otherwise it has no explicit base class.
    """
    is_v3 = bool(info.get("is_v3", False))
    function_name = "_pyisolate_execute"
    restored_input_types = restore_input_types(info.get("input_types", {}))

    async def _execute(self, **inputs):
        # Execute entry point installed on the stub class under `function_name`.
        from comfy.isolation import _RUNNING_EXTENSIONS

        # Update BOTH the local dict AND the module-level dict
        running_extensions[extension.name] = extension
        _RUNNING_EXTENSIONS[extension.name] = extension
        # Previous PYISOLATE_CHILD value; stays None unless the variable was
        # set and has been popped below.
        prev_child = None
        node_unique_id = _extract_hidden_unique_id(inputs)
        summary = _tensor_transport_summary(inputs)
        resources = _resource_snapshot()
        logger.debug(
            "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
            LOG_PREFIX,
            extension.name,
            node_name,
            node_unique_id or "-",
            summary["tensor_count"],
            summary["cpu_tensors"],
            summary["cuda_tensors"],
            summary["shared_cpu_tensors"],
            summary["tensor_bytes"],
            resources["fd_count"],
            resources["shm_sender_files"],
        )
        scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
        try:
            # Memory relief runs only when isolation is flagged active via env.
            if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1":
                _relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
                scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
            from pyisolate._internal.model_serialization import (
                serialize_for_isolation,
                deserialize_from_isolation,
            )

            # PYISOLATE_CHILD is removed from the environment for the
            # duration of the call and put back in the `finally` clause.
            prev_child = os.environ.pop("PYISOLATE_CHILD", None)
            logger.debug(
                "%s ISO:serialize_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            serialized = serialize_for_isolation(inputs)
            logger.debug(
                "%s ISO:serialize_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            logger.debug(
                "%s ISO:dispatch_start ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            result = await extension.execute_node(node_name, **serialized)
            logger.debug(
                "%s ISO:dispatch_done ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            deserialized = await deserialize_from_isolation(result, extension)
            scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
            # Clone shared-memory CPU tensors before returning the result.
            return _detach_shared_cpu_tensors(deserialized)
        except ImportError:
            # Fallback: dispatch the raw, unserialized inputs.
            # NOTE(review): this handler catches an ImportError raised by ANY
            # statement in the try body (including the awaited execute_node
            # call above), not only the pyisolate import - in that case the
            # node is dispatched a second time. Confirm this is intended.
            return await extension.execute_node(node_name, **inputs)
        except Exception:
            logger.exception(
                "%s ISO:execute_error ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            raise
        finally:
            # Restore the variable only if it was set before the call.
            if prev_child is not None:
                os.environ["PYISOLATE_CHILD"] = prev_child
            logger.debug(
                "%s ISO:execute_end ext=%s node=%s uid=%s",
                LOG_PREFIX,
                extension.name,
                node_name,
                node_unique_id or "-",
            )
            scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)

    def _input_types(
        cls,
        include_hidden: bool = True,
        return_schema: bool = False,
        live_inputs: Any = None,
    ):
        # INPUT_TYPES classmethod of the stub. `live_inputs` is accepted but
        # never read in this function.
        if not is_v3:
            # Non-v3 nodes get the shared restored mapping itself (no copy).
            return restored_input_types

        # v3 nodes work on a deep copy so the pops below do not alter the
        # mapping captured by the closure.
        inputs_copy = copy.deepcopy(restored_input_types)
        if not include_hidden:
            inputs_copy.pop("hidden", None)

        v3_data: Dict[str, Any] = {"hidden_inputs": {}}
        dynamic = inputs_copy.pop("dynamic_paths", None)
        if dynamic is not None:
            v3_data["dynamic_paths"] = dynamic

        if return_schema:
            hidden_vals = info.get("hidden", []) or []
            hidden_enums = []
            for h in hidden_vals:
                # Convert to latest_io.Hidden where possible; keep the raw
                # value when the conversion raises.
                try:
                    hidden_enums.append(latest_io.Hidden(h))
                except Exception:
                    hidden_enums.append(h)

            # Minimal stand-in object exposing only a `hidden` attribute.
            class SchemaProxy:
                hidden = hidden_enums

            return inputs_copy, SchemaProxy, v3_data
        return inputs_copy

    def _validate_class(cls):
        # Always reports success; performs no checks.
        return True

    def _get_node_info_v1(cls):
        # Returns the "schema_v1" entry of `info` (empty dict when absent).
        return info.get("schema_v1", {})

    def _get_base_class(cls):
        return latest_io.ComfyNode

    # Class namespace handed to type() below.
    attributes: Dict[str, object] = {
        "FUNCTION": function_name,
        "CATEGORY": info.get("category", ""),
        "OUTPUT_NODE": info.get("output_node", False),
        "RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
        "RETURN_NAMES": info.get("return_names"),
        function_name: _execute,
        "_pyisolate_extension": extension,
        "_pyisolate_node_name": node_name,
        "INPUT_TYPES": classmethod(_input_types),
    }

    # OUTPUT_IS_LIST is only defined when the metadata provides it.
    output_is_list = info.get("output_is_list")
    if output_is_list is not None:
        attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)

    # Extra attributes added for v3 nodes only.
    if is_v3:
        attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
        attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
        attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
        attributes["DESCRIPTION"] = info.get("description", "")
        attributes["EXPERIMENTAL"] = info.get("experimental", False)
        attributes["DEPRECATED"] = info.get("deprecated", False)
        attributes["API_NODE"] = info.get("api_node", False)
        attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
        attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)

    class_name = f"PyIsolate_{node_name}".replace(" ", "_")
    bases = (_ComfyNodeInternal,) if is_v3 else ()
    stub_cls = type(class_name, bases, attributes)

    if is_v3:
        # A failing VALIDATE_CLASS is logged but does not prevent the class
        # from being returned.
        try:
            stub_cls.VALIDATE_CLASS()
        except Exception as e:
            logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)

    return stub_cls
|
|
|
|
|
|
def get_class_types_for_extension(
    extension_name: str,
    running_extensions: Dict[str, "ComfyNodeExtension"],
    specs: List[Any],
) -> Set[str]:
    """Return the node names whose spec belongs to the named running extension.

    A spec belongs to the extension when its resolved ``module_path`` equals
    the extension's resolved ``module_path``.

    Args:
        extension_name: Key to look up in ``running_extensions``.
        running_extensions: Mapping of extension name to extension object.
        specs: Objects exposing ``module_path`` (a ``Path``) and ``node_name``.

    Returns:
        The set of matching ``node_name`` values; an empty set when the
        extension is not registered (or is falsy).
    """
    extension = running_extensions.get(extension_name)
    if not extension:
        return set()

    # Resolve once up front: Path.resolve() may touch the filesystem and its
    # result is the same for every spec.
    ext_path = Path(extension.module_path).resolve()
    return {
        spec.node_name for spec in specs if spec.module_path.resolve() == ext_path
    }
|
|
|
|
|
|
# Public names of this module; the underscore-prefixed helpers are not listed.
__all__ = ["build_stub_class", "get_class_types_for_extension"]
|