diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index fcd7ef735..24dd1ffd0 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict: } +def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]: + """Extract the workflow id from a prompt's ``extra_data``. + + The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]`` + when a prompt is queued. Any value that is not a non-empty string is treated as + missing so callers can rely on the return being either ``None`` or a string. + """ + if not isinstance(extra_data, dict): + return None + extra_pnginfo = extra_data.get('extra_pnginfo') + if not isinstance(extra_pnginfo, dict): + return None + workflow = extra_pnginfo.get('workflow') + if not isinstance(workflow, dict): + return None + workflow_id = workflow.get('id') + if isinstance(workflow_id, str) and workflow_id: + return workflow_id + return None + + def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: """Extract create_time and workflow_id from extra_data. @@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str tuple: (create_time, workflow_id) """ create_time = extra_data.get('create_time') - extra_pnginfo = extra_data.get('extra_pnginfo', {}) - workflow_id = extra_pnginfo.get('workflow', {}).get('id') + workflow_id = extract_workflow_id(extra_data) return create_time, workflow_id diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py index f951a3350..b6d3bd3e4 100644 --- a/comfy_execution/progress.py +++ b/comfy_execution/progress.py @@ -164,6 +164,8 @@ class WebUIProgressHandler(ProgressHandler): if self.server_instance is None: return + workflow_id = self.registry.workflow_id if self.registry else None + # Only send info for non-pending nodes active_nodes = { node_id: { @@ -172,6 +174,7 @@ class WebUIProgressHandler(ProgressHandler): "state": state["state"].value, "node_id": node_id, "prompt_id": prompt_id, + "workflow_id": workflow_id, "display_node_id": self.registry.dynprompt.get_display_node_id(node_id), "parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id), "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), @@ -183,7 +186,7 @@ class WebUIProgressHandler(ProgressHandler): # Send a combined progress_state message with all node states # Include client_id to ensure message is only sent to the initiating client self.server_instance.send_sync( - "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id + "progress_state", {"prompt_id": prompt_id, "workflow_id": workflow_id, "nodes": active_nodes}, self.server_instance.client_id ) @override @@ -215,6 +218,7 @@ class WebUIProgressHandler(ProgressHandler): metadata = { "node_id": node_id, "prompt_id": prompt_id, + "workflow_id": self.registry.workflow_id if self.registry else None, "display_node_id": self.registry.dynprompt.get_display_node_id( node_id ), @@ -240,9 +244,10 @@ class ProgressRegistry: Registry that maintains node progress state and notifies registered handlers. """ - def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"): + def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None): self.prompt_id = prompt_id self.dynprompt = dynprompt + self.workflow_id = workflow_id self.nodes: Dict[str, NodeProgressState] = {} self.handlers: Dict[str, ProgressHandler] = {} @@ -322,7 +327,7 @@ class ProgressRegistry: # Global registry instance global_progress_registry: ProgressRegistry | None = None -def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: +def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None: global global_progress_registry # Reset existing handlers if registry exists @@ -330,7 +335,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: global_progress_registry.reset_handlers() # Create new registry - global_progress_registry = ProgressRegistry(prompt_id, dynprompt) + global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id) def add_progress_handler(handler: ProgressHandler) -> None: diff --git a/execution.py b/execution.py index 5a6d3404c..4f6b5f2c5 100644 --- a/execution.py +++ b/execution.py @@ -37,6 +37,7 @@ from comfy_execution.graph import ( from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler +from comfy_execution.jobs import extract_workflow_id from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io @@ -416,15 +417,15 @@ def _is_intermediate_output(dynprompt, node_id): class_def = nodes.NODE_CLASS_MAPPINGS[class_type] return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False) -def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs): +def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs): if server.client_id is None: return cached_ui = cached.ui or {} - server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id) if cached.ui is not None: ui_outputs[node_id] = cached.ui -async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): +async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -434,7 +435,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, class_def = nodes.NODE_CLASS_MAPPINGS[class_type] cached = await caches.outputs.get(unique_id) if cached is not None: - _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs) + _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs) get_progress_state().finish_progress(unique_id) execution_list.cache_update(unique_id, cached) return (ExecutionResult.SUCCESS, None, None) @@ -482,7 +483,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id - server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id) obj = await caches.objects.get(unique_id) if obj is None: @@ -512,6 +513,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if block.message is not None: mes = { "prompt_id": prompt_id, + "workflow_id": workflow_id, "node_id": unique_id, "node_type": class_type, "executed": list(executed), @@ -559,7 +561,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, "output": output_ui } if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id) if has_subgraph: cached_outputs = [] new_node_ids = [] @@ -656,6 +658,7 @@ class PromptExecutor: self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) self.status_messages = [] self.success = True + self.workflow_id = None def add_message(self, event, data: dict, broadcast: bool): data = { @@ -675,6 +678,7 @@ class PromptExecutor: if isinstance(ex, comfy.model_management.InterruptProcessingException): mes = { "prompt_id": prompt_id, + "workflow_id": self.workflow_id, "node_id": node_id, "node_type": class_type, "executed": list(executed), @@ -683,6 +687,7 @@ class PromptExecutor: else: mes = { "prompt_id": prompt_id, + "workflow_id": self.workflow_id, "node_id": node_id, "node_type": class_type, "executed": list(executed), @@ -721,7 +726,9 @@ class PromptExecutor: self.server.client_id = None self.status_messages = [] - self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) + self.workflow_id = extract_workflow_id(extra_data) + self.server.last_workflow_id = self.workflow_id + self.add_message("execution_start", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False) self._notify_prompt_lifecycle("start", prompt_id) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) @@ -731,7 +738,7 @@ class PromptExecutor: try: with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) + reset_progress_state(prompt_id, dynamic_prompt, self.workflow_id) add_progress_handler(WebUIProgressHandler(self.server)) is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) for cache in self.caches.all: @@ -749,7 +756,7 @@ class PromptExecutor: comfy.model_management.cleanup_models_gc() self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, + { "nodes": cached_nodes, "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False) pending_subgraph_results = {} pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results @@ -767,7 +774,7 @@ class PromptExecutor: break assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, self.workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -791,8 +798,8 @@ class PromptExecutor: cached = await self.caches.outputs.get(node_id) if cached is not None: display_node_id = dynamic_prompt.get_display_node_id(node_id) - _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs) - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, self.workflow_id, ui_node_outputs) + self.add_message("execution_success", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False) ui_outputs = {} meta_outputs = {} diff --git a/main.py b/main.py index dbaf2745c..1ccaf24c1 100644 --- a/main.py +++ b/main.py @@ -322,7 +322,7 @@ def prompt_worker(q, server_instance): completed=e.success, messages=e.status_messages), process_item=remove_sensitive) if server_instance.client_id is not None: - server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None)}, server_instance.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time @@ -385,7 +385,7 @@ def hijack_progress(server_instance): prompt_id = server_instance.last_prompt_id if node_id is None: node_id = server_instance.last_node_id - progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id} + progress = {"value": value, "max": total, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None), "node": node_id} get_progress_state().update_progress(node_id, value, total, preview_image) server_instance.send_sync("progress", progress, server_instance.client_id) diff --git a/tests-unit/execution_test/test_workflow_id_in_ws_messages.py b/tests-unit/execution_test/test_workflow_id_in_ws_messages.py new file mode 100644 index 000000000..0645c02c0 --- /dev/null +++ b/tests-unit/execution_test/test_workflow_id_in_ws_messages.py @@ -0,0 +1,111 @@ +"""Tests that workflow_id is included alongside prompt_id in WebSocket payloads +emitted by the progress handler and the prompt executor. + +Frontend stores extra_data["extra_pnginfo"]["workflow"]["id"] when queueing a +prompt; we propagate that as `workflow_id` on every execution event so a +multi-tab UI can scope progress state by workflow even when terminal +WebSocket frames are dropped. +""" + +from unittest.mock import MagicMock + +import pytest + +from comfy_execution.progress import ( + NodeState, + ProgressRegistry, + WebUIProgressHandler, + reset_progress_state, + get_progress_state, +) + + +class _DummyDynPrompt: + def get_display_node_id(self, node_id): + return node_id + + def get_parent_node_id(self, node_id): + return None + + def get_real_node_id(self, node_id): + return node_id + + +@pytest.fixture +def server(): + s = MagicMock() + s.client_id = "client-1" + return s + + +def _registry(workflow_id): + return ProgressRegistry( + prompt_id="prompt-1", + dynprompt=_DummyDynPrompt(), + workflow_id=workflow_id, + ) + + +class TestProgressStatePayload: + def test_progress_state_includes_workflow_id(self, server): + registry = _registry("wf-abc") + registry.nodes["n1"] = { + "state": NodeState.Running, + "value": 1.0, + "max": 5.0, + } + + handler = WebUIProgressHandler(server) + handler.set_registry(registry) + handler._send_progress_state("prompt-1", registry.nodes) + + server.send_sync.assert_called_once() + event, payload, sid = server.send_sync.call_args.args + assert event == "progress_state" + assert payload["prompt_id"] == "prompt-1" + assert payload["workflow_id"] == "wf-abc" + assert payload["nodes"]["n1"]["workflow_id"] == "wf-abc" + assert payload["nodes"]["n1"]["prompt_id"] == "prompt-1" + assert sid == "client-1" + + def test_progress_state_workflow_id_none_when_missing(self, server): + registry = _registry(None) + registry.nodes["n1"] = { + "state": NodeState.Running, + "value": 0.5, + "max": 1.0, + } + + handler = WebUIProgressHandler(server) + handler.set_registry(registry) + handler._send_progress_state("prompt-1", registry.nodes) + + _, payload, _ = server.send_sync.call_args.args + assert payload["workflow_id"] is None + assert payload["nodes"]["n1"]["workflow_id"] is None + + +class TestProgressRegistryConstruction: + def test_workflow_id_default_is_none(self): + registry = ProgressRegistry( + prompt_id="prompt-1", dynprompt=_DummyDynPrompt() + ) + assert registry.workflow_id is None + + def test_workflow_id_stored_on_registry(self): + registry = ProgressRegistry( + prompt_id="prompt-1", + dynprompt=_DummyDynPrompt(), + workflow_id="wf-xyz", + ) + assert registry.workflow_id == "wf-xyz" + + +class TestResetProgressState: + def test_reset_threads_workflow_id(self): + reset_progress_state("prompt-1", _DummyDynPrompt(), "wf-456") + assert get_progress_state().workflow_id == "wf-456" + + def test_reset_default_workflow_id_none(self): + reset_progress_state("prompt-2", _DummyDynPrompt()) + assert get_progress_state().workflow_id is None diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 814af5c13..6afa6cd9c 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -10,9 +10,44 @@ from comfy_execution.jobs import ( get_outputs_summary, apply_sorting, has_3d_extension, + extract_workflow_id, ) +class TestExtractWorkflowId: + """Unit tests for extract_workflow_id().""" + + def test_returns_id_from_extra_pnginfo(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 'wf-123'}}}) == 'wf-123' + + def test_missing_extra_data_returns_none(self): + assert extract_workflow_id(None) is None + + def test_non_dict_extra_data_returns_none(self): + assert extract_workflow_id('not-a-dict') is None + + def test_missing_extra_pnginfo_returns_none(self): + assert extract_workflow_id({}) is None + + def test_missing_workflow_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {}}) is None + + def test_missing_id_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': {}}}) is None + + def test_empty_string_id_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': ''}}}) is None + + def test_non_string_id_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 42}}}) is None + + def test_non_dict_workflow_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': 'not-a-dict'}}) is None + + def test_non_dict_extra_pnginfo_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': 'not-a-dict'}) is None + + class TestJobStatus: """Test JobStatus constants."""