diff --git a/.ci/windows_intel_base_files/run_intel_gpu.bat b/.ci/windows_intel_base_files/run_intel_gpu.bat index 274d7c948..3c845624d 100755 --- a/.ci/windows_intel_base_files/run_intel_gpu.bat +++ b/.ci/windows_intel_base_files/run_intel_gpu.bat @@ -1,2 +1,2 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build -pause +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build +pause diff --git a/comfy/hooks.py b/comfy/hooks.py index 1a76c7ba4..7a5f69ca7 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -14,6 +14,9 @@ if TYPE_CHECKING: import comfy.lora import comfy.model_management import comfy.patcher_extension +from comfy.cli_args import args +import uuid +import os from node_helpers import conditioning_set_values # ####################################################################################################### @@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum): HookedOnly = "hooked_only" +_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + + class _HookRef: - pass + def __init__(self): + if _ISOLATION_HOOKREF_MODE: + self._pyisolate_id = str(uuid.uuid4()) + + def _ensure_pyisolate_id(self): + pyisolate_id = getattr(self, "_pyisolate_id", None) + if pyisolate_id is None: + pyisolate_id = str(uuid.uuid4()) + self._pyisolate_id = pyisolate_id + return pyisolate_id + + def __eq__(self, other): + if not _ISOLATION_HOOKREF_MODE: + return self is other + if not isinstance(other, _HookRef): + return False + return self._ensure_pyisolate_id() == other._ensure_pyisolate_id() + + def __hash__(self): + if not _ISOLATION_HOOKREF_MODE: + return id(self) + return hash(self._ensure_pyisolate_id()) + + def __str__(self): + if not _ISOLATION_HOOKREF_MODE: + return super().__str__() + return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}" def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -168,6 +200,8 @@ class WeightHook(Hook): key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) else: key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if self.weights is None: + self.weights = {} weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) else: if target == EnumWeightTarget.Clip: diff --git a/comfy/isolation/host_policy.py b/comfy/isolation/host_policy.py new file mode 100644 index 000000000..91b8e5b96 --- /dev/null +++ b/comfy/isolation/host_policy.py @@ -0,0 +1,180 @@ +# pylint: disable=logging-fstring-interpolation +from __future__ import annotations + +import logging +import os +from pathlib import Path +from pathlib import PurePosixPath +from typing import Dict, List, TypedDict + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +logger = logging.getLogger(__name__) + +HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH" +VALID_SANDBOX_MODES = frozenset({"required", "disabled"}) +FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"}) + + +class HostSecurityPolicy(TypedDict): + sandbox_mode: str + allow_network: bool + writable_paths: List[str] + readonly_paths: List[str] + sealed_worker_ro_import_paths: List[str] + whitelist: Dict[str, str] + + +DEFAULT_POLICY: HostSecurityPolicy = { + "sandbox_mode": "required", + "allow_network": False, + "writable_paths": ["/dev/shm"], + "readonly_paths": [], + "sealed_worker_ro_import_paths": [], + "whitelist": {}, +} + + +def _default_policy() -> HostSecurityPolicy: + return { + "sandbox_mode": 
DEFAULT_POLICY["sandbox_mode"], + "allow_network": DEFAULT_POLICY["allow_network"], + "writable_paths": list(DEFAULT_POLICY["writable_paths"]), + "readonly_paths": list(DEFAULT_POLICY["readonly_paths"]), + "sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]), + "whitelist": dict(DEFAULT_POLICY["whitelist"]), + } + + +def _normalize_writable_paths(paths: list[object]) -> list[str]: + normalized_paths: list[str] = [] + for raw_path in paths: + # Host-policy paths are contract-style POSIX paths; keep representation + # stable across Windows/Linux so tests and config behavior stay consistent. + normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/"))) + if normalized_path in FORBIDDEN_WRITABLE_PATHS: + continue + normalized_paths.append(normalized_path) + return normalized_paths + + +def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]: + if not file_path.is_absolute(): + file_path = config_path.parent / file_path + if not file_path.exists(): + logger.warning("whitelist_file %s not found, skipping.", file_path) + return {} + entries: Dict[str, str] = {} + for line in file_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + entries[line] = "*" + logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path) + return entries + + +def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]: + if not isinstance(raw_paths, list): + raise ValueError( + "tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths." + ) + + normalized_paths: list[str] = [] + seen: set[str] = set() + for raw_path in raw_paths: + if not isinstance(raw_path, str) or not raw_path.strip(): + raise ValueError( + "tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings." + ) + normalized_path = str(PurePosixPath(raw_path.replace("\\", "/"))) + # Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...) + is_absolute = normalized_path.startswith("/") or ( + len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/" + ) + if not is_absolute: + raise ValueError( + "tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths." 
+ ) + if normalized_path not in seen: + seen.add(normalized_path) + normalized_paths.append(normalized_path) + + return normalized_paths + + +def load_host_policy(comfy_root: Path) -> HostSecurityPolicy: + config_override = os.environ.get(HOST_POLICY_PATH_ENV) + config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml" + policy = _default_policy() + + if not config_path.exists(): + logger.debug("Host policy file missing at %s, using defaults.", config_path) + return policy + + try: + with config_path.open("rb") as f: + data = tomllib.load(f) + except Exception: + logger.warning( + "Failed to parse host policy from %s, using defaults.", + config_path, + exc_info=True, + ) + return policy + + tool_config = data.get("tool", {}).get("comfy", {}).get("host", {}) + if not isinstance(tool_config, dict): + logger.debug("No [tool.comfy.host] section found, using defaults.") + return policy + + sandbox_mode = tool_config.get("sandbox_mode") + if isinstance(sandbox_mode, str): + normalized_sandbox_mode = sandbox_mode.strip().lower() + if normalized_sandbox_mode in VALID_SANDBOX_MODES: + policy["sandbox_mode"] = normalized_sandbox_mode + else: + logger.warning( + "Invalid host sandbox_mode %r in %s, using default %r.", + sandbox_mode, + config_path, + DEFAULT_POLICY["sandbox_mode"], + ) + + if "allow_network" in tool_config: + policy["allow_network"] = bool(tool_config["allow_network"]) + + if "writable_paths" in tool_config: + policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"]) + + if "readonly_paths" in tool_config: + policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]] + + if "sealed_worker_ro_import_paths" in tool_config: + policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths( + tool_config["sealed_worker_ro_import_paths"] + ) + + whitelist_file = tool_config.get("whitelist_file") + if isinstance(whitelist_file, str): + policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path)) + + whitelist_raw = tool_config.get("whitelist") + if isinstance(whitelist_raw, dict): + policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()}) + + os.environ["PYISOLATE_SANDBOX_MODE"] = policy["sandbox_mode"] + + logger.debug( + "Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s", + len(policy["whitelist"]), + policy["sandbox_mode"], + policy["allow_network"], + ) + return policy + + +__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"] diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6978eb717..4ed4a9250 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,4 +1,5 @@ import math +import os from functools import partial from scipy import integrate @@ -12,8 +13,8 @@ from . import deis from . import sa_solver import comfy.model_patcher import comfy.model_sampling - import comfy.memory_management +from comfy.cli_args import args from comfy.utils import model_trange as trange def append_zero(x): @@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, """Implements Algorithm 2 (Euler steps) from Karras et al. 
(2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + if isolation_active: + target_device = sigmas.device + if x.device != target_device: + x = x.to(target_device) + s_in = s_in.to(target_device) + for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. diff --git a/comfy/model_base.py b/comfy/model_base.py index 50dab5782..534aa33bd 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1 import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging +import os import comfy.ldm.lightricks.av_model import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -120,8 +121,20 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.V_PREDICTION_DDPM: c = comfy.model_sampling.V_PREDICTION_DDPM + from comfy.cli_args import args + isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + class ModelSampling(s, c): - pass + if isolation_runtime_enabled: + def __reduce__(self): + """Ensure pickling yields a proxy instead of failing on local class.""" + try: + from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy + registry = ModelSamplingRegistry() + ms_id = registry.register(self) + return (ModelSamplingProxy, (ms_id,)) + except Exception as exc: + raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc return ModelSampling(model_config) diff --git a/comfy/model_management.py b/comfy/model_management.py index 95af40012..297eaf6a4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -498,6 +498,9 @@ except: current_loaded_models = [] +def _isolation_mode_enabled(): + return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + def module_size(module): module_mem = 0 sd = module.state_dict() @@ -604,8 +607,9 @@ class LoadedModel: if freed >= memory_to_free: return False self.model.detach(unpatch_weights) - self.model_finalizer.detach() - self.model_finalizer = None + if self.model_finalizer is not None: + self.model_finalizer.detach() + self.model_finalizer = None self.real_model = None return True @@ -619,8 +623,15 @@ class LoadedModel: if self._patcher_finalizer is not None: self._patcher_finalizer.detach() + def dead_state(self): + model_ref_gone = self.model is None + real_model_ref = self.real_model + real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None + return model_ref_gone, real_model_ref_gone + def is_dead(self): - return self.real_model() is not None and self.model is None + model_ref_gone, real_model_ref_gone = self.dead_state() + return model_ref_gone or real_model_ref_gone def use_more_memory(extra_memory, loaded_models, device): @@ -667,6 +678,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins unloaded_model = [] can_unload = [] unloaded_models = [] + isolation_active = _isolation_mode_enabled() for i in range(len(current_loaded_models) -1, -1, -1): shift_model = current_loaded_models[i] @@ -675,6 +687,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) 
shift_model.currently_used = False + if can_unload and isolation_active: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + flush_tensor_keeper = None + if callable(flush_tensor_keeper): + flushed = flush_tensor_keeper() + if flushed > 0: + logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed) + gc.collect() + can_unload_sorted = sorted(can_unload) for x in can_unload_sorted: i = x[-1] @@ -705,7 +728,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") for i in sorted(unloaded_model, reverse=True): - unloaded_models.append(current_loaded_models.pop(i)) + unloaded = current_loaded_models.pop(i) + model_obj = unloaded.model + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() + unloaded_models.append(unloaded) if len(unloaded_model) > 0: soft_empty_cache() @@ -764,7 +793,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu for i in to_unload: model_to_unload = current_loaded_models.pop(i) model_to_unload.model.detach(unpatch_all=False) - model_to_unload.model_finalizer.detach() + if model_to_unload.model_finalizer is not None: + model_to_unload.model_finalizer.detach() + model_to_unload.model_finalizer = None total_memory_required = {} @@ -837,25 +868,62 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): - do_gc = False - reset_cast_buffers() + if not _isolation_mode_enabled(): + dead_found = False + for i in range(len(current_loaded_models)): + if current_loaded_models[i].is_dead(): + dead_found = True + break + if dead_found: + logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.") + gc.collect() + soft_empty_cache() + + for i in range(len(current_loaded_models) - 1, -1, -1): + cur = current_loaded_models[i] + if cur.is_dead(): + logging.warning("WARNING, memory leak with model NoneType. 
Please make sure it is not being referenced from somewhere.") + leaked = current_loaded_models.pop(i) + model_obj = getattr(leaked, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() + return + + dead_found = False + has_real_model_leak = False for i in range(len(current_loaded_models)): - cur = current_loaded_models[i] - if cur.is_dead(): - logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__)) - do_gc = True - break + model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state() + if model_ref_gone or real_model_ref_gone: + dead_found = True + if real_model_ref_gone and not model_ref_gone: + has_real_model_leak = True - if do_gc: + if dead_found: + if has_real_model_leak: + logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.") + else: + logging.debug("Cleaning stale loaded-model entries with released patcher references.") gc.collect() soft_empty_cache() - for i in range(len(current_loaded_models)): + for i in range(len(current_loaded_models) - 1, -1, -1): cur = current_loaded_models[i] - if cur.is_dead(): - logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) + model_ref_gone, real_model_ref_gone = cur.dead_state() + if model_ref_gone or real_model_ref_gone: + if real_model_ref_gone and not model_ref_gone: + logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.") + else: + logging.debug("Cleaning stale loaded-model entry with released patcher reference.") + leaked = current_loaded_models.pop(i) + model_obj = getattr(leaked, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() def archive_model_dtypes(model): @@ -869,11 +937,20 @@ def archive_model_dtypes(model): def cleanup_models(): to_delete = [] for i in range(len(current_loaded_models)): - if current_loaded_models[i].real_model() is None: + real_model_ref = current_loaded_models[i].real_model + if real_model_ref is None: + to_delete = [i] + to_delete + continue + if callable(real_model_ref) and real_model_ref() is None: to_delete = [i] + to_delete for i in to_delete: x = current_loaded_models.pop(i) + model_obj = getattr(x, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() del x def dtype_size(dtype): diff --git a/comfy/samplers.py b/comfy/samplers.py index 0a4d062db..35bea21d8 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -11,12 +11,14 @@ from functools import partial import collections import math import logging +import os import comfy.sampler_helpers import comfy.model_patcher import comfy.patcher_extension import comfy.hooks import comfy.context_windows import comfy.utils +from comfy.cli_args import args import scipy.stats import numpy @@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc _calc_cond_batch, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) ) - return executor.execute(model, conds, x_in, timestep, model_options) + 
result = executor.execute(model, conds, x_in, timestep, model_options) + return result def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" out_conds = [] out_counts = [] # separate conds by matching hooks @@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for k, v in to_run[tt][0].conditioning.items(): cond_shapes[k].append(v.size()) - if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: + memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes) + if memory_required * 1.5 < free_memory: to_batch = batch_amount break @@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens patches = p.patches batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) + if isolation_active: + target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device + input_x = torch.cat(input_x).to(target_device) + else: + input_x = torch.cat(input_x) c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) + if isolation_active: + timestep_ = torch.cat([timestep] * batch_chunks).to(target_device) + mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult] + else: + timestep_ = torch.cat([timestep] * batch_chunks) transformer_options = model.current_patcher.apply_hooks(hooks=hooks) if 'transformer_options' in model_options: @@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for o in range(batch_chunks): cond_index = cond_or_uncond[o] a = area[o] + out_t = output[o] + mult_t = mult[o] + if isolation_active: + target_dev = out_conds[cond_index].device + if hasattr(out_t, "device") and out_t.device != target_dev: + out_t = out_t.to(target_dev) + if hasattr(mult_t, "device") and mult_t.device != target_dev: + mult_t = mult_t.to(target_dev) if a is None: - out_conds[cond_index] += output[o] * mult[o] - out_counts[cond_index] += mult[o] + out_conds[cond_index] += out_t * mult_t + out_counts[cond_index] += mult_t else: out_c = out_conds[cond_index] out_cts = out_counts[cond_index] @@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for i in range(dims): out_c = out_c.narrow(i + 2, a[i + dims], a[i]) out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) - out_c += output[o] * mult[o] - out_cts += mult[o] + out_c += out_t * mult_t + out_cts += mult_t for i in range(len(out_conds)): out_conds[i] /= out_counts[i] @@ -392,14 +413,31 @@ class KSamplerX0Inpaint: self.inner_model = model self.sigmas = sigmas def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" if denoise_mask is not None: + if isolation_active and denoise_mask.device != x.device: + denoise_mask = denoise_mask.to(x.device) if "denoise_mask_function" in model_options: denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. 
- denoise_mask - x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask + if isolation_active: + latent_image = self.latent_image + if hasattr(latent_image, "device") and latent_image.device != x.device: + latent_image = latent_image.to(x.device) + scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image) + if hasattr(scaled, "device") and scaled.device != x.device: + scaled = scaled.to(x.device) + else: + scaled = self.inner_model.inner_model.scale_latent_inpaint( + x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image + ) + x = x * denoise_mask + scaled * latent_mask out = self.inner_model(x, sigma, model_options=model_options, seed=seed) if denoise_mask is not None: - out = out * denoise_mask + self.latent_image * latent_mask + latent_image = self.latent_image + if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device: + latent_image = latent_image.to(out.device) + out = out * denoise_mask + latent_image * latent_mask return out def simple_scheduler(model_sampling, steps): @@ -741,7 +779,11 @@ class KSAMPLER(Sampler): else: model_k.noise = noise - noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)) + max_denoise = self.max_denoise(model_wrap, sigmas) + model_sampling = model_wrap.inner_model.model_sampling + noise = model_sampling.noise_scaling( + sigmas[0], noise, latent_image, max_denoise + ) k_callback = None total_steps = len(sigmas) - 1 diff --git a/cuda_malloc.py b/cuda_malloc.py index f7651981c..f6d2063e9 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -92,7 +92,7 @@ if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" - else: + elif not args.use_process_isolation: env_var += ",backend:cudaMallocAsync" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
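
For reference, a minimal sketch of how the new load_host_policy() in comfy/isolation/host_policy.py might be exercised against a throwaway pyproject.toml. This is illustrative only: it assumes the comfy package is importable and that COMFY_HOST_POLICY_PATH is unset, and the [tool.comfy.host] keys simply mirror the ones the loader reads above; the concrete values are made-up examples, not recommended settings.

# Sketch: feed load_host_policy() a temporary pyproject.toml and inspect the result.
from pathlib import Path
import tempfile

from comfy.isolation.host_policy import load_host_policy

SAMPLE = """
[tool.comfy.host]
sandbox_mode = "required"        # "required" or "disabled"; anything else falls back to the default
allow_network = false
writable_paths = ["/dev/shm"]    # "/tmp" entries are dropped by the loader
sealed_worker_ro_import_paths = ["/opt/comfy/shared_libs"]  # example path, must be absolute

[tool.comfy.host.whitelist]
"some-custom-node" = "*"
"""

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "pyproject.toml").write_text(SAMPLE)
    policy = load_host_policy(root)
    # The loader also exports PYISOLATE_SANDBOX_MODE as a side effect.
    print(policy["sandbox_mode"], policy["allow_network"], policy["whitelist"])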