feat(isolation): sandbox policy and runtime fencing

John Pollock 2026-04-30 14:51:29 -05:00
parent fbb6be5624
commit c9ebc6aa57
8 changed files with 388 additions and 34 deletions

View File

@ -1,2 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause

View File

@ -14,6 +14,9 @@ if TYPE_CHECKING:
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from comfy.cli_args import args
import uuid
import os
from node_helpers import conditioning_set_values
# #######################################################################################################
@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
HookedOnly = "hooked_only"
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
class _HookRef:
pass
def __init__(self):
if _ISOLATION_HOOKREF_MODE:
self._pyisolate_id = str(uuid.uuid4())
def _ensure_pyisolate_id(self):
pyisolate_id = getattr(self, "_pyisolate_id", None)
if pyisolate_id is None:
pyisolate_id = str(uuid.uuid4())
self._pyisolate_id = pyisolate_id
return pyisolate_id
def __eq__(self, other):
if not _ISOLATION_HOOKREF_MODE:
return self is other
if not isinstance(other, _HookRef):
return False
return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
def __hash__(self):
if not _ISOLATION_HOOKREF_MODE:
return id(self)
return hash(self._ensure_pyisolate_id())
def __str__(self):
if not _ISOLATION_HOOKREF_MODE:
return super().__str__()
return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
@ -168,6 +200,8 @@ class WeightHook(Hook):
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if self.weights is None:
self.weights = {}
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Clip:

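Note on the _HookRef change above: hook references now cross a process boundary when isolation is enabled, and the stock identity-based __eq__/__hash__ cannot match a pickled copy against its original. A minimal sketch of the stable-identity pattern, using a stand-in class rather than the real _HookRef:

import pickle
import uuid

class StableRef:
    def __init__(self):
        # Ordinary instance state, so the id survives pickling into a child process.
        self._id = str(uuid.uuid4())

    def __eq__(self, other):
        return isinstance(other, StableRef) and self._id == other._id

    def __hash__(self):
        return hash(self._id)

ref = StableRef()
clone = pickle.loads(pickle.dumps(ref))            # what an isolated worker would receive
assert ref == clone and hash(ref) == hash(clone)   # holds even though ref is not clone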
View File

@ -0,0 +1,180 @@
# pylint: disable=logging-fstring-interpolation
from __future__ import annotations
import logging
import os
from pathlib import Path
from pathlib import PurePosixPath
from typing import Dict, List, TypedDict
try:
import tomllib
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
logger = logging.getLogger(__name__)
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"})
class HostSecurityPolicy(TypedDict):
sandbox_mode: str
allow_network: bool
writable_paths: List[str]
readonly_paths: List[str]
sealed_worker_ro_import_paths: List[str]
whitelist: Dict[str, str]
DEFAULT_POLICY: HostSecurityPolicy = {
"sandbox_mode": "required",
"allow_network": False,
"writable_paths": ["/dev/shm"],
"readonly_paths": [],
"sealed_worker_ro_import_paths": [],
"whitelist": {},
}
def _default_policy() -> HostSecurityPolicy:
return {
"sandbox_mode": DEFAULT_POLICY["sandbox_mode"],
"allow_network": DEFAULT_POLICY["allow_network"],
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
"sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]),
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
}
def _normalize_writable_paths(paths: list[object]) -> list[str]:
normalized_paths: list[str] = []
for raw_path in paths:
# Host-policy paths are contract-style POSIX paths; keep representation
# stable across Windows/Linux so tests and config behavior stay consistent.
normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/")))
if normalized_path in FORBIDDEN_WRITABLE_PATHS:
continue
normalized_paths.append(normalized_path)
return normalized_paths
def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]:
if not file_path.is_absolute():
file_path = config_path.parent / file_path
if not file_path.exists():
logger.warning("whitelist_file %s not found, skipping.", file_path)
return {}
entries: Dict[str, str] = {}
for line in file_path.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
entries[line] = "*"
logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path)
return entries
def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]:
if not isinstance(raw_paths, list):
raise ValueError(
"tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths."
)
normalized_paths: list[str] = []
seen: set[str] = set()
for raw_path in raw_paths:
if not isinstance(raw_path, str) or not raw_path.strip():
raise ValueError(
"tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings."
)
normalized_path = str(PurePosixPath(raw_path.replace("\\", "/")))
# Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...)
is_absolute = normalized_path.startswith("/") or (
len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/"
)
if not is_absolute:
raise ValueError(
"tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths."
)
if normalized_path not in seen:
seen.add(normalized_path)
normalized_paths.append(normalized_path)
return normalized_paths
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
policy = _default_policy()
if not config_path.exists():
logger.debug("Host policy file missing at %s, using defaults.", config_path)
return policy
try:
with config_path.open("rb") as f:
data = tomllib.load(f)
except Exception:
logger.warning(
"Failed to parse host policy from %s, using defaults.",
config_path,
exc_info=True,
)
return policy
tool_config = data.get("tool", {}).get("comfy", {}).get("host", {})
if not isinstance(tool_config, dict):
logger.debug("No [tool.comfy.host] section found, using defaults.")
return policy
sandbox_mode = tool_config.get("sandbox_mode")
if isinstance(sandbox_mode, str):
normalized_sandbox_mode = sandbox_mode.strip().lower()
if normalized_sandbox_mode in VALID_SANDBOX_MODES:
policy["sandbox_mode"] = normalized_sandbox_mode
else:
logger.warning(
"Invalid host sandbox_mode %r in %s, using default %r.",
sandbox_mode,
config_path,
DEFAULT_POLICY["sandbox_mode"],
)
if "allow_network" in tool_config:
policy["allow_network"] = bool(tool_config["allow_network"])
if "writable_paths" in tool_config:
policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"])
if "readonly_paths" in tool_config:
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
if "sealed_worker_ro_import_paths" in tool_config:
policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths(
tool_config["sealed_worker_ro_import_paths"]
)
whitelist_file = tool_config.get("whitelist_file")
if isinstance(whitelist_file, str):
policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path))
whitelist_raw = tool_config.get("whitelist")
if isinstance(whitelist_raw, dict):
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
os.environ["PYISOLATE_SANDBOX_MODE"] = policy["sandbox_mode"]
logger.debug(
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
len(policy["whitelist"]),
policy["sandbox_mode"],
policy["allow_network"],
)
return policy
__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]

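For reference, a hedged example of the [tool.comfy.host] table this loader reads; the key names come from the code above, the values are illustrative, and the snippet only parses the TOML the way the loader would (it does not import the new module, whose final path is not shown in this diff). The same table can live in pyproject.toml at the ComfyUI root or in a standalone file pointed to by COMFY_HOST_POLICY_PATH.

try:
    import tomllib            # Python 3.11+
except ImportError:
    import tomli as tomllib   # same fallback as the loader

EXAMPLE = '''
[tool.comfy.host]
sandbox_mode = "required"
allow_network = false
writable_paths = ["/dev/shm", "/tmp"]
sealed_worker_ro_import_paths = ["/opt/models"]

[tool.comfy.host.whitelist]
"Some-Custom-Node" = "*"
'''

host = tomllib.loads(EXAMPLE)["tool"]["comfy"]["host"]
# load_host_policy() would keep sandbox_mode and allow_network as-is, drop "/tmp"
# (a FORBIDDEN_WRITABLE_PATHS entry), require "/opt/models" to be absolute, and merge
# the whitelist table into policy["whitelist"].
print(host["sandbox_mode"], host["writable_paths"])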
View File

@ -1,4 +1,5 @@
import math
import os
from functools import partial
from scipy import integrate
@ -12,8 +13,8 @@ from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
from comfy.cli_args import args
from comfy.utils import model_trange as trange
def append_zero(x):
@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
if isolation_active:
target_device = sigmas.device
if x.device != target_device:
x = x.to(target_device)
s_in = s_in.to(target_device)
for i in trange(len(sigmas) - 1, disable=disable):
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.

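The fence added to sample_euler only matters when tensors cross the isolation boundary: a latent rehydrated on the host side can arrive on CPU while the sigma schedule already lives on the sampling device. A stand-alone illustration of the same guard (CPU-only here; device names are illustrative):

import torch

def fence_to(reference: torch.Tensor, *tensors: torch.Tensor) -> list[torch.Tensor]:
    # Move each tensor onto reference.device only when it is not already there.
    return [t if t.device == reference.device else t.to(reference.device) for t in tensors]

sigmas = torch.linspace(1.0, 0.0, 10)     # stands in for the sampler's sigma schedule
x = torch.randn(1, 4, 8, 8)               # stands in for the latent being denoised
s_in = x.new_ones([x.shape[0]])
x, s_in = fence_to(sigmas, x, s_in)       # no-op on one device; copies under isolation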
View File

@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import os
import comfy.ldm.lightricks.av_model
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@ -120,8 +121,20 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_DDPM:
c = comfy.model_sampling.V_PREDICTION_DDPM
from comfy.cli_args import args
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
class ModelSampling(s, c):
pass
if isolation_runtime_enabled:
def __reduce__(self):
"""Ensure pickling yields a proxy instead of failing on local class."""
try:
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
registry = ModelSamplingRegistry()
ms_id = registry.register(self)
return (ModelSamplingProxy, (ms_id,))
except Exception as exc:
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
return ModelSampling(model_config)

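ModelSampling is defined inside model_sampling(), so pickle cannot locate it by qualified name; the conditional __reduce__ sidesteps that by returning a picklable proxy plus a registry id. A self-contained sketch of that protocol with hypothetical Registry/Proxy stand-ins (the real ones are expected in comfy.isolation.model_sampling_proxy, which is not part of this diff):

import pickle

_REGISTRY: dict[int, object] = {}

class Proxy:
    # Picklable placeholder; a child process would resolve calls back through the id.
    def __init__(self, obj_id: int):
        self.obj_id = obj_id

def make_local_instance():
    class Local:                        # local class: pickling it normally fails
        def __reduce__(self):
            obj_id = id(self)
            _REGISTRY[obj_id] = self    # the real registry would live on the host side
            return (Proxy, (obj_id,))   # unpickling calls Proxy(obj_id)
    return Local()

obj = make_local_instance()
restored = pickle.loads(pickle.dumps(obj))
assert isinstance(restored, Proxy) and _REGISTRY[restored.obj_id] is obj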
View File

@ -498,6 +498,9 @@ except:
current_loaded_models = []
def _isolation_mode_enabled():
return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
def module_size(module):
module_mem = 0
sd = module.state_dict()
@ -604,8 +607,9 @@ class LoadedModel:
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
if self.model_finalizer is not None:
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
return True
@ -619,8 +623,15 @@ class LoadedModel:
if self._patcher_finalizer is not None:
self._patcher_finalizer.detach()
def dead_state(self):
model_ref_gone = self.model is None
real_model_ref = self.real_model
real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None
return model_ref_gone, real_model_ref_gone
def is_dead(self):
return self.real_model() is not None and self.model is None
model_ref_gone, real_model_ref_gone = self.dead_state()
return model_ref_gone or real_model_ref_gone
def use_more_memory(extra_memory, loaded_models, device):
@ -667,6 +678,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
unloaded_model = []
can_unload = []
unloaded_models = []
isolation_active = _isolation_mode_enabled()
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
@ -675,6 +687,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
if can_unload and isolation_active:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
flush_tensor_keeper = None
if callable(flush_tensor_keeper):
flushed = flush_tensor_keeper()
if flushed > 0:
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
gc.collect()
can_unload_sorted = sorted(can_unload)
for x in can_unload_sorted:
i = x[-1]
@ -705,7 +728,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
unloaded = current_loaded_models.pop(i)
model_obj = unloaded.model
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
unloaded_models.append(unloaded)
if len(unloaded_model) > 0:
soft_empty_cache()
@ -764,7 +793,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
for i in to_unload:
model_to_unload = current_loaded_models.pop(i)
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
if model_to_unload.model_finalizer is not None:
model_to_unload.model_finalizer.detach()
model_to_unload.model_finalizer = None
total_memory_required = {}
@ -837,25 +868,62 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
do_gc = False
reset_cast_buffers()
if not _isolation_mode_enabled():
dead_found = False
for i in range(len(current_loaded_models)):
if current_loaded_models[i].is_dead():
dead_found = True
break
if dead_found:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models) - 1, -1, -1):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
leaked = current_loaded_models.pop(i)
model_obj = getattr(leaked, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
return
dead_found = False
has_real_model_leak = False
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
do_gc = True
break
model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
if model_ref_gone or real_model_ref_gone:
dead_found = True
if real_model_ref_gone and not model_ref_gone:
has_real_model_leak = True
if do_gc:
if dead_found:
if has_real_model_leak:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
else:
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models)):
for i in range(len(current_loaded_models) - 1, -1, -1):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
model_ref_gone, real_model_ref_gone = cur.dead_state()
if model_ref_gone or real_model_ref_gone:
if real_model_ref_gone and not model_ref_gone:
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
else:
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
leaked = current_loaded_models.pop(i)
model_obj = getattr(leaked, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
def archive_model_dtypes(model):
@ -869,11 +937,20 @@ def archive_model_dtypes(model):
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
if current_loaded_models[i].real_model() is None:
real_model_ref = current_loaded_models[i].real_model
if real_model_ref is None:
to_delete = [i] + to_delete
continue
if callable(real_model_ref) and real_model_ref() is None:
to_delete = [i] + to_delete
for i in to_delete:
x = current_loaded_models.pop(i)
model_obj = getattr(x, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
del x
def dtype_size(dtype):

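dead_state() splits the old is_dead() check into two independent signals: the strong patcher reference (self.model) being cleared, and the real-model weak reference having died. In the new cleanup path, "weakref dead but patcher reference still set" is logged as a leak, while other dead combinations are treated as stale entries to remove quietly. A small weakref sketch of those two states (FakeModel stands in for the real module):

import gc
import weakref

class FakeModel:
    pass

class Entry:
    def __init__(self, model):
        self.model = model                     # strong reference, like LoadedModel.model
        self.real_model = weakref.ref(model)   # weak reference, like LoadedModel.real_model

    def dead_state(self):
        model_ref_gone = self.model is None
        real_model_ref_gone = callable(self.real_model) and self.real_model() is None
        return model_ref_gone, real_model_ref_gone

entry = Entry(FakeModel())
print(entry.dead_state())   # (False, False): both references alive
entry.model = None          # drop the only strong reference
gc.collect()
print(entry.dead_state())   # (True, True): treated as a stale entry, not a leak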
View File

@ -11,12 +11,14 @@ from functools import partial
import collections
import math
import logging
import os
import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.utils
from comfy.cli_args import args
import scipy.stats
import numpy
@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
result = executor.execute(model, conds, x_in, timestep, model_options)
return result
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
out_conds = []
out_counts = []
# separate conds by matching hooks
@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
if memory_required * 1.5 < free_memory:
to_batch = batch_amount
break
@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
if isolation_active:
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
input_x = torch.cat(input_x).to(target_device)
else:
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if isolation_active:
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
else:
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
out_t = output[o]
mult_t = mult[o]
if isolation_active:
target_dev = out_conds[cond_index].device
if hasattr(out_t, "device") and out_t.device != target_dev:
out_t = out_t.to(target_dev)
if hasattr(mult_t, "device") and mult_t.device != target_dev:
mult_t = mult_t.to(target_dev)
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
out_conds[cond_index] += out_t * mult_t
out_counts[cond_index] += mult_t
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
out_c += out_t * mult_t
out_cts += mult_t
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
@ -392,14 +413,31 @@ class KSamplerX0Inpaint:
self.inner_model = model
self.sigmas = sigmas
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
if denoise_mask is not None:
if isolation_active and denoise_mask.device != x.device:
denoise_mask = denoise_mask.to(x.device)
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
if isolation_active:
latent_image = self.latent_image
if hasattr(latent_image, "device") and latent_image.device != x.device:
latent_image = latent_image.to(x.device)
scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image)
if hasattr(scaled, "device") and scaled.device != x.device:
scaled = scaled.to(x.device)
else:
scaled = self.inner_model.inner_model.scale_latent_inpaint(
x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image
)
x = x * denoise_mask + scaled * latent_mask
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask
latent_image = self.latent_image
if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device:
latent_image = latent_image.to(out.device)
out = out * denoise_mask + latent_image * latent_mask
return out
def simple_scheduler(model_sampling, steps):
@ -741,7 +779,11 @@ class KSAMPLER(Sampler):
else:
model_k.noise = noise
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
max_denoise = self.max_denoise(model_wrap, sigmas)
model_sampling = model_wrap.inner_model.model_sampling
noise = model_sampling.noise_scaling(
sigmas[0], noise, latent_image, max_denoise
)
k_callback = None
total_steps = len(sigmas) - 1

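The per-tensor fences in _calc_cond_batch exist because outputs coming back from an isolated worker may sit on a different device than the accumulation buffers, which were allocated alongside x_in; an in-place += across devices raises. A compressed, CPU-only sketch of that accumulation step (shapes and the 1e-37 initialiser are illustrative):

import torch

out_conds = [torch.zeros(2, 4, 8, 8)]           # buffer allocated like x_in
out_counts = [torch.ones(2, 4, 8, 8) * 1e-37]
output = [torch.randn(2, 4, 8, 8)]              # chunk produced by the (possibly remote) model
mult = [torch.ones(2, 4, 8, 8)]

for o in range(len(output)):
    out_t, mult_t = output[o], mult[o]
    target_dev = out_conds[0].device
    if out_t.device != target_dev:              # the isolation fence: align before the in-place add
        out_t = out_t.to(target_dev)
    if mult_t.device != target_dev:
        mult_t = mult_t.to(target_dev)
    out_conds[0] += out_t * mult_t
    out_counts[0] += mult_t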
View File

@ -92,7 +92,7 @@ if args.cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
else:
elif not args.use_process_isolation:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
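
The allocator tweak above: with process isolation on, an operator-supplied PYTORCH_CUDA_ALLOC_CONF is left untouched instead of having cudaMallocAsync appended. A pure-function restatement of that decision, for illustration only:

def alloc_conf(existing: str | None, use_process_isolation: bool) -> str:
    # Mirrors the branch above: always use the async backend when nothing was set,
    # and only extend an existing value when isolation is off.
    if existing is None:
        return "backend:cudaMallocAsync"
    if not use_process_isolation:
        return existing + ",backend:cudaMallocAsync"
    return existing

print(alloc_conf(None, False))                        # backend:cudaMallocAsync
print(alloc_conf("expandable_segments:True", True))   # expandable_segments:True (unchanged)
print(alloc_conf("expandable_segments:True", False))  # expandable_segments:True,backend:cudaMallocAsync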