# pylint: disable=import-outside-toplevel from __future__ import annotations import asyncio import logging from typing import Any from comfy.isolation.proxies.base import ( BaseProxy, BaseRegistry, detach_if_grad, get_thread_loop, run_coro_in_new_loop, ) logger = logging.getLogger(__name__) def _prefer_device(*tensors: Any) -> Any: try: import torch except Exception: return None for t in tensors: if isinstance(t, torch.Tensor) and t.is_cuda: return t.device for t in tensors: if isinstance(t, torch.Tensor): return t.device return None def _to_device(obj: Any, device: Any) -> Any: try: import torch except Exception: return obj if device is None: return obj if isinstance(obj, torch.Tensor): if obj.device != device: return obj.to(device) return obj if isinstance(obj, (list, tuple)): converted = [_to_device(x, device) for x in obj] return type(obj)(converted) if isinstance(obj, tuple) else converted if isinstance(obj, dict): return {k: _to_device(v, device) for k, v in obj.items()} return obj class ModelSamplingRegistry(BaseRegistry[Any]): _type_prefix = "modelsampling" async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any: sampling = self._get_instance(instance_id) return detach_if_grad(sampling.calculate_input(sigma, noise)) async def calculate_denoised( self, instance_id: str, sigma: Any, model_output: Any, model_input: Any ) -> Any: sampling = self._get_instance(instance_id) return detach_if_grad( sampling.calculate_denoised(sigma, model_output, model_input) ) async def noise_scaling( self, instance_id: str, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False, ) -> Any: sampling = self._get_instance(instance_id) return detach_if_grad( sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise) ) async def inverse_noise_scaling( self, instance_id: str, sigma: Any, latent: Any ) -> Any: sampling = self._get_instance(instance_id) return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent)) async def timestep(self, instance_id: str, sigma: Any) -> Any: sampling = self._get_instance(instance_id) return sampling.timestep(sigma) async def sigma(self, instance_id: str, timestep: Any) -> Any: sampling = self._get_instance(instance_id) return sampling.sigma(timestep) async def percent_to_sigma(self, instance_id: str, percent: float) -> Any: sampling = self._get_instance(instance_id) return sampling.percent_to_sigma(percent) async def get_sigma_min(self, instance_id: str) -> Any: sampling = self._get_instance(instance_id) return detach_if_grad(sampling.sigma_min) async def get_sigma_max(self, instance_id: str) -> Any: sampling = self._get_instance(instance_id) return detach_if_grad(sampling.sigma_max) async def get_sigma_data(self, instance_id: str) -> Any: sampling = self._get_instance(instance_id) return detach_if_grad(sampling.sigma_data) async def get_sigmas(self, instance_id: str) -> Any: sampling = self._get_instance(instance_id) return detach_if_grad(sampling.sigmas) async def set_sigmas(self, instance_id: str, sigmas: Any) -> None: sampling = self._get_instance(instance_id) sampling.set_sigmas(sigmas) class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]): _registry_class = ModelSamplingRegistry __module__ = "comfy.isolation.model_sampling_proxy" def _get_rpc(self) -> Any: if self._rpc_caller is None: from pyisolate._internal.rpc_protocol import get_child_rpc_instance rpc = get_child_rpc_instance() if rpc is not None: self._rpc_caller = rpc.create_caller( ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id() ) else: registry = ModelSamplingRegistry() class _LocalCaller: def calculate_input( self, instance_id: str, sigma: Any, noise: Any ) -> Any: return registry.calculate_input(instance_id, sigma, noise) def calculate_denoised( self, instance_id: str, sigma: Any, model_output: Any, model_input: Any, ) -> Any: return registry.calculate_denoised( instance_id, sigma, model_output, model_input ) def noise_scaling( self, instance_id: str, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False, ) -> Any: return registry.noise_scaling( instance_id, sigma, noise, latent_image, max_denoise ) def inverse_noise_scaling( self, instance_id: str, sigma: Any, latent: Any ) -> Any: return registry.inverse_noise_scaling( instance_id, sigma, latent ) def timestep(self, instance_id: str, sigma: Any) -> Any: return registry.timestep(instance_id, sigma) def sigma(self, instance_id: str, timestep: Any) -> Any: return registry.sigma(instance_id, timestep) def percent_to_sigma(self, instance_id: str, percent: float) -> Any: return registry.percent_to_sigma(instance_id, percent) def get_sigma_min(self, instance_id: str) -> Any: return registry.get_sigma_min(instance_id) def get_sigma_max(self, instance_id: str) -> Any: return registry.get_sigma_max(instance_id) def get_sigma_data(self, instance_id: str) -> Any: return registry.get_sigma_data(instance_id) def get_sigmas(self, instance_id: str) -> Any: return registry.get_sigmas(instance_id) def set_sigmas(self, instance_id: str, sigmas: Any) -> None: return registry.set_sigmas(instance_id, sigmas) self._rpc_caller = _LocalCaller() return self._rpc_caller def _call(self, method_name: str, *args: Any) -> Any: rpc = self._get_rpc() method = getattr(rpc, method_name) result = method(self._instance_id, *args) if asyncio.iscoroutine(result): try: asyncio.get_running_loop() return run_coro_in_new_loop(result) except RuntimeError: loop = get_thread_loop() return loop.run_until_complete(result) return result @property def sigma_min(self) -> Any: return self._call("get_sigma_min") @property def sigma_max(self) -> Any: return self._call("get_sigma_max") @property def sigma_data(self) -> Any: return self._call("get_sigma_data") @property def sigmas(self) -> Any: return self._call("get_sigmas") def calculate_input(self, sigma: Any, noise: Any) -> Any: return self._call("calculate_input", sigma, noise) def calculate_denoised( self, sigma: Any, model_output: Any, model_input: Any ) -> Any: return self._call("calculate_denoised", sigma, model_output, model_input) def noise_scaling( self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False ) -> Any: return self._call("noise_scaling", sigma, noise, latent_image, max_denoise) def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any: return self._call("inverse_noise_scaling", sigma, latent) def timestep(self, sigma: Any) -> Any: return self._call("timestep", sigma) def sigma(self, timestep: Any) -> Any: return self._call("sigma", timestep) def percent_to_sigma(self, percent: float) -> Any: return self._call("percent_to_sigma", percent) def set_sigmas(self, sigmas: Any) -> None: return self._call("set_sigmas", sigmas)