mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-05 05:46:12 +02:00
feat(isolation): sandbox policy and runtime fencing
This commit is contained in:
parent
fbb6be5624
commit
c9ebc6aa57
@ -1,2 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
pause
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
pause
|
||||
|
||||
@ -14,6 +14,9 @@ if TYPE_CHECKING:
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from comfy.cli_args import args
|
||||
import uuid
|
||||
import os
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
# #######################################################################################################
|
||||
@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
|
||||
HookedOnly = "hooked_only"
|
||||
|
||||
|
||||
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
def __init__(self):
|
||||
if _ISOLATION_HOOKREF_MODE:
|
||||
self._pyisolate_id = str(uuid.uuid4())
|
||||
|
||||
def _ensure_pyisolate_id(self):
|
||||
pyisolate_id = getattr(self, "_pyisolate_id", None)
|
||||
if pyisolate_id is None:
|
||||
pyisolate_id = str(uuid.uuid4())
|
||||
self._pyisolate_id = pyisolate_id
|
||||
return pyisolate_id
|
||||
|
||||
def __eq__(self, other):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return self is other
|
||||
if not isinstance(other, _HookRef):
|
||||
return False
|
||||
return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
|
||||
|
||||
def __hash__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return id(self)
|
||||
return hash(self._ensure_pyisolate_id())
|
||||
|
||||
def __str__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return super().__str__()
|
||||
return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
|
||||
|
||||
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
@ -168,6 +200,8 @@ class WeightHook(Hook):
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if self.weights is None:
|
||||
self.weights = {}
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
|
||||
180
comfy/isolation/host_policy.py
Normal file
180
comfy/isolation/host_policy.py
Normal file
@ -0,0 +1,180 @@
|
||||
# pylint: disable=logging-fstring-interpolation
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Dict, List, TypedDict
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
|
||||
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
|
||||
FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"})
|
||||
|
||||
|
||||
class HostSecurityPolicy(TypedDict):
|
||||
sandbox_mode: str
|
||||
allow_network: bool
|
||||
writable_paths: List[str]
|
||||
readonly_paths: List[str]
|
||||
sealed_worker_ro_import_paths: List[str]
|
||||
whitelist: Dict[str, str]
|
||||
|
||||
|
||||
DEFAULT_POLICY: HostSecurityPolicy = {
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": ["/dev/shm"],
|
||||
"readonly_paths": [],
|
||||
"sealed_worker_ro_import_paths": [],
|
||||
"whitelist": {},
|
||||
}
|
||||
|
||||
|
||||
def _default_policy() -> HostSecurityPolicy:
|
||||
return {
|
||||
"sandbox_mode": DEFAULT_POLICY["sandbox_mode"],
|
||||
"allow_network": DEFAULT_POLICY["allow_network"],
|
||||
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
|
||||
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
|
||||
"sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]),
|
||||
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_writable_paths(paths: list[object]) -> list[str]:
|
||||
normalized_paths: list[str] = []
|
||||
for raw_path in paths:
|
||||
# Host-policy paths are contract-style POSIX paths; keep representation
|
||||
# stable across Windows/Linux so tests and config behavior stay consistent.
|
||||
normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/")))
|
||||
if normalized_path in FORBIDDEN_WRITABLE_PATHS:
|
||||
continue
|
||||
normalized_paths.append(normalized_path)
|
||||
return normalized_paths
|
||||
|
||||
|
||||
def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]:
|
||||
if not file_path.is_absolute():
|
||||
file_path = config_path.parent / file_path
|
||||
if not file_path.exists():
|
||||
logger.warning("whitelist_file %s not found, skipping.", file_path)
|
||||
return {}
|
||||
entries: Dict[str, str] = {}
|
||||
for line in file_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
entries[line] = "*"
|
||||
logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path)
|
||||
return entries
|
||||
|
||||
|
||||
def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]:
|
||||
if not isinstance(raw_paths, list):
|
||||
raise ValueError(
|
||||
"tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths."
|
||||
)
|
||||
|
||||
normalized_paths: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for raw_path in raw_paths:
|
||||
if not isinstance(raw_path, str) or not raw_path.strip():
|
||||
raise ValueError(
|
||||
"tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings."
|
||||
)
|
||||
normalized_path = str(PurePosixPath(raw_path.replace("\\", "/")))
|
||||
# Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...)
|
||||
is_absolute = normalized_path.startswith("/") or (
|
||||
len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/"
|
||||
)
|
||||
if not is_absolute:
|
||||
raise ValueError(
|
||||
"tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths."
|
||||
)
|
||||
if normalized_path not in seen:
|
||||
seen.add(normalized_path)
|
||||
normalized_paths.append(normalized_path)
|
||||
|
||||
return normalized_paths
|
||||
|
||||
|
||||
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
|
||||
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
|
||||
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
|
||||
policy = _default_policy()
|
||||
|
||||
if not config_path.exists():
|
||||
logger.debug("Host policy file missing at %s, using defaults.", config_path)
|
||||
return policy
|
||||
|
||||
try:
|
||||
with config_path.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse host policy from %s, using defaults.",
|
||||
config_path,
|
||||
exc_info=True,
|
||||
)
|
||||
return policy
|
||||
|
||||
tool_config = data.get("tool", {}).get("comfy", {}).get("host", {})
|
||||
if not isinstance(tool_config, dict):
|
||||
logger.debug("No [tool.comfy.host] section found, using defaults.")
|
||||
return policy
|
||||
|
||||
sandbox_mode = tool_config.get("sandbox_mode")
|
||||
if isinstance(sandbox_mode, str):
|
||||
normalized_sandbox_mode = sandbox_mode.strip().lower()
|
||||
if normalized_sandbox_mode in VALID_SANDBOX_MODES:
|
||||
policy["sandbox_mode"] = normalized_sandbox_mode
|
||||
else:
|
||||
logger.warning(
|
||||
"Invalid host sandbox_mode %r in %s, using default %r.",
|
||||
sandbox_mode,
|
||||
config_path,
|
||||
DEFAULT_POLICY["sandbox_mode"],
|
||||
)
|
||||
|
||||
if "allow_network" in tool_config:
|
||||
policy["allow_network"] = bool(tool_config["allow_network"])
|
||||
|
||||
if "writable_paths" in tool_config:
|
||||
policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"])
|
||||
|
||||
if "readonly_paths" in tool_config:
|
||||
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
|
||||
|
||||
if "sealed_worker_ro_import_paths" in tool_config:
|
||||
policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths(
|
||||
tool_config["sealed_worker_ro_import_paths"]
|
||||
)
|
||||
|
||||
whitelist_file = tool_config.get("whitelist_file")
|
||||
if isinstance(whitelist_file, str):
|
||||
policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path))
|
||||
|
||||
whitelist_raw = tool_config.get("whitelist")
|
||||
if isinstance(whitelist_raw, dict):
|
||||
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
|
||||
|
||||
os.environ["PYISOLATE_SANDBOX_MODE"] = policy["sandbox_mode"]
|
||||
|
||||
logger.debug(
|
||||
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
|
||||
len(policy["whitelist"]),
|
||||
policy["sandbox_mode"],
|
||||
policy["allow_network"],
|
||||
)
|
||||
return policy
|
||||
|
||||
|
||||
__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]
|
||||
@ -1,4 +1,5 @@
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
from scipy import integrate
|
||||
@ -12,8 +13,8 @@ from . import deis
|
||||
from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
import comfy.memory_management
|
||||
from comfy.cli_args import args
|
||||
from comfy.utils import model_trange as trange
|
||||
|
||||
def append_zero(x):
|
||||
@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if isolation_active:
|
||||
target_device = sigmas.device
|
||||
if x.device != target_device:
|
||||
x = x.to(target_device)
|
||||
s_in = s_in.to(target_device)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
if s_churn > 0:
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
|
||||
@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
import comfy.ldm.lightricks.av_model
|
||||
import comfy.context_windows
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
@ -120,8 +121,20 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.V_PREDICTION_DDPM:
|
||||
c = comfy.model_sampling.V_PREDICTION_DDPM
|
||||
|
||||
from comfy.cli_args import args
|
||||
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
if isolation_runtime_enabled:
|
||||
def __reduce__(self):
|
||||
"""Ensure pickling yields a proxy instead of failing on local class."""
|
||||
try:
|
||||
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
|
||||
registry = ModelSamplingRegistry()
|
||||
ms_id = registry.register(self)
|
||||
return (ModelSamplingProxy, (ms_id,))
|
||||
except Exception as exc:
|
||||
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
|
||||
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
@ -498,6 +498,9 @@ except:
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
def _isolation_mode_enabled():
|
||||
return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
@ -604,8 +607,9 @@ class LoadedModel:
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.detach(unpatch_weights)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
if self.model_finalizer is not None:
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
@ -619,8 +623,15 @@ class LoadedModel:
|
||||
if self._patcher_finalizer is not None:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def dead_state(self):
|
||||
model_ref_gone = self.model is None
|
||||
real_model_ref = self.real_model
|
||||
real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None
|
||||
return model_ref_gone, real_model_ref_gone
|
||||
|
||||
def is_dead(self):
|
||||
return self.real_model() is not None and self.model is None
|
||||
model_ref_gone, real_model_ref_gone = self.dead_state()
|
||||
return model_ref_gone or real_model_ref_gone
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
@ -667,6 +678,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
isolation_active = _isolation_mode_enabled()
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
@ -675,6 +687,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
if can_unload and isolation_active:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
flush_tensor_keeper = None
|
||||
if callable(flush_tensor_keeper):
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
|
||||
gc.collect()
|
||||
|
||||
can_unload_sorted = sorted(can_unload)
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
@ -705,7 +728,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
unloaded = current_loaded_models.pop(i)
|
||||
model_obj = unloaded.model
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
unloaded_models.append(unloaded)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
@ -764,7 +793,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
for i in to_unload:
|
||||
model_to_unload = current_loaded_models.pop(i)
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
if model_to_unload.model_finalizer is not None:
|
||||
model_to_unload.model_finalizer.detach()
|
||||
model_to_unload.model_finalizer = None
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
@ -837,25 +868,62 @@ def loaded_models(only_currently_used=False):
|
||||
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
if not _isolation_mode_enabled():
|
||||
dead_found = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].is_dead():
|
||||
dead_found = True
|
||||
break
|
||||
|
||||
if dead_found:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
return
|
||||
|
||||
dead_found = False
|
||||
has_real_model_leak = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
do_gc = True
|
||||
break
|
||||
model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
dead_found = True
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
has_real_model_leak = True
|
||||
|
||||
if do_gc:
|
||||
if dead_found:
|
||||
if has_real_model_leak:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
model_ref_gone, real_model_ref_gone = cur.dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
|
||||
|
||||
def archive_model_dtypes(model):
|
||||
@ -869,11 +937,20 @@ def archive_model_dtypes(model):
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].real_model() is None:
|
||||
real_model_ref = current_loaded_models[i].real_model
|
||||
if real_model_ref is None:
|
||||
to_delete = [i] + to_delete
|
||||
continue
|
||||
if callable(real_model_ref) and real_model_ref() is None:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
x = current_loaded_models.pop(i)
|
||||
model_obj = getattr(x, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
del x
|
||||
|
||||
def dtype_size(dtype):
|
||||
|
||||
@ -11,12 +11,14 @@ from functools import partial
|
||||
import collections
|
||||
import math
|
||||
import logging
|
||||
import os
|
||||
import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.utils
|
||||
from comfy.cli_args import args
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
result = executor.execute(model, conds, x_in, timestep, model_options)
|
||||
return result
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
|
||||
if memory_required * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
if isolation_active:
|
||||
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
|
||||
input_x = torch.cat(input_x).to(target_device)
|
||||
else:
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
if isolation_active:
|
||||
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
|
||||
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
|
||||
else:
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
out_t = output[o]
|
||||
mult_t = mult[o]
|
||||
if isolation_active:
|
||||
target_dev = out_conds[cond_index].device
|
||||
if hasattr(out_t, "device") and out_t.device != target_dev:
|
||||
out_t = out_t.to(target_dev)
|
||||
if hasattr(mult_t, "device") and mult_t.device != target_dev:
|
||||
mult_t = mult_t.to(target_dev)
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
out_conds[cond_index] += out_t * mult_t
|
||||
out_counts[cond_index] += mult_t
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
out_c += out_t * mult_t
|
||||
out_cts += mult_t
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
@ -392,14 +413,31 @@ class KSamplerX0Inpaint:
|
||||
self.inner_model = model
|
||||
self.sigmas = sigmas
|
||||
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if denoise_mask is not None:
|
||||
if isolation_active and denoise_mask.device != x.device:
|
||||
denoise_mask = denoise_mask.to(x.device)
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
|
||||
if isolation_active:
|
||||
latent_image = self.latent_image
|
||||
if hasattr(latent_image, "device") and latent_image.device != x.device:
|
||||
latent_image = latent_image.to(x.device)
|
||||
scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image)
|
||||
if hasattr(scaled, "device") and scaled.device != x.device:
|
||||
scaled = scaled.to(x.device)
|
||||
else:
|
||||
scaled = self.inner_model.inner_model.scale_latent_inpaint(
|
||||
x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image
|
||||
)
|
||||
x = x * denoise_mask + scaled * latent_mask
|
||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
latent_image = self.latent_image
|
||||
if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device:
|
||||
latent_image = latent_image.to(out.device)
|
||||
out = out * denoise_mask + latent_image * latent_mask
|
||||
return out
|
||||
|
||||
def simple_scheduler(model_sampling, steps):
|
||||
@ -741,7 +779,11 @@ class KSAMPLER(Sampler):
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
|
||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||
model_sampling = model_wrap.inner_model.model_sampling
|
||||
noise = model_sampling.noise_scaling(
|
||||
sigmas[0], noise, latent_image, max_denoise
|
||||
)
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
|
||||
@ -92,7 +92,7 @@ if args.cuda_malloc:
|
||||
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
||||
if env_var is None:
|
||||
env_var = "backend:cudaMallocAsync"
|
||||
else:
|
||||
elif not args.use_process_isolation:
|
||||
env_var += ",backend:cudaMallocAsync"
|
||||
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user