From 7578d1989f72b7c7d5fdc92d76966e6f2e077474 Mon Sep 17 00:00:00 2001
From: Cursor Agent
Date: Fri, 1 May 2026 00:35:43 +0000
Subject: [PATCH] Add workflow_id to all websocket messages

This commit addresses BE-672 by ensuring all execution-related websocket
messages include the workflow_id field when available.

Changes:
- Added extract_workflow_id() helper function in comfy_execution/jobs.py
  to extract workflow_id from extra_data
- Updated execution.py to include workflow_id in all websocket messages:
  - execution_start
  - execution_cached
  - execution_success
  - execution_error
  - execution_interrupted
  - executing
  - executed (including cached UI)
- Updated main.py to include workflow_id in:
  - progress messages (via hijack_progress hook)
  - final executing message (node=None)
- Updated comfy_execution/progress.py to include workflow_id in:
  - progress_state messages
  - preview image metadata

The workflow_id is extracted from extra_data['extra_pnginfo']['workflow']['id']
and is conditionally included in messages only when present, maintaining
backward compatibility with workflows that don't have this field.

Fixes: BE-672

Co-authored-by: Luke Mino-Altherr
---
 comfy_execution/jobs.py     | 15 ++++++++++
 comfy_execution/progress.py | 14 ++++++---
 execution.py                | 60 ++++++++++++++++++++++++++-----------
 main.py                     | 13 +++++++-
 4 files changed, 80 insertions(+), 22 deletions(-)

diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py
index fcd7ef735..dbf0fb17d 100644
--- a/comfy_execution/jobs.py
+++ b/comfy_execution/jobs.py
@@ -105,6 +105,21 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
     return create_time, workflow_id
 
 
+def extract_workflow_id(extra_data: dict) -> Optional[str]:
+    """Extract workflow_id from extra_data.
+
+    Args:
+        extra_data: The extra_data dict containing workflow information
+
+    Returns:
+        The workflow_id if present, otherwise None
+    """
+    if not extra_data:
+        return None
+    extra_pnginfo = extra_data.get('extra_pnginfo', {})
+    return extra_pnginfo.get('workflow', {}).get('id')
+
+
 def is_previewable(media_type: str, item: dict) -> bool:
     """
     Check if an output item is previewable.
diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py
index f951a3350..c721f8722 100644
--- a/comfy_execution/progress.py
+++ b/comfy_execution/progress.py
@@ -182,8 +182,11 @@ class WebUIProgressHandler(ProgressHandler):
 
         # Send a combined progress_state message with all node states
        # Include client_id to ensure message is only sent to the initiating client
+        message = {"prompt_id": prompt_id, "nodes": active_nodes}
+        if self.registry.workflow_id is not None:
+            message["workflow_id"] = self.registry.workflow_id
         self.server_instance.send_sync(
-            "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
+            "progress_state", message, self.server_instance.client_id
         )
 
     @override
@@ -223,6 +226,8 @@ class WebUIProgressHandler(ProgressHandler):
                 ),
                 "real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
             }
+            if self.registry.workflow_id is not None:
+                metadata["workflow_id"] = self.registry.workflow_id
             self.server_instance.send_sync(
                 BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
                 (image, metadata),
@@ -240,9 +245,10 @@ class ProgressRegistry:
     Registry that maintains node progress state and notifies registered handlers.
""" - def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"): + def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None): self.prompt_id = prompt_id self.dynprompt = dynprompt + self.workflow_id = workflow_id self.nodes: Dict[str, NodeProgressState] = {} self.handlers: Dict[str, ProgressHandler] = {} @@ -322,7 +328,7 @@ class ProgressRegistry: # Global registry instance global_progress_registry: ProgressRegistry | None = None -def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: +def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None: global global_progress_registry # Reset existing handlers if registry exists @@ -330,7 +336,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: global_progress_registry.reset_handlers() # Create new registry - global_progress_registry = ProgressRegistry(prompt_id, dynprompt) + global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id) def add_progress_handler(handler: ProgressHandler) -> None: diff --git a/execution.py b/execution.py index 5a6d3404c..4f4c7913e 100644 --- a/execution.py +++ b/execution.py @@ -41,6 +41,7 @@ from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger +from comfy_execution.jobs import extract_workflow_id class ExecutionResult(Enum): @@ -416,15 +417,18 @@ def _is_intermediate_output(dynprompt, node_id): class_def = nodes.NODE_CLASS_MAPPINGS[class_type] return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False) -def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs): +def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs, workflow_id=None): if server.client_id is None: return cached_ui = cached.ui or {} - server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id) + message = { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id } + if workflow_id is not None: + message["workflow_id"] = workflow_id + server.send_sync("executed", message, server.client_id) if cached.ui is not None: ui_outputs[node_id] = cached.ui -async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): +async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs, workflow_id=None): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -434,7 +438,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, class_def = nodes.NODE_CLASS_MAPPINGS[class_type] cached = await caches.outputs.get(unique_id) if cached is not None: - _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs) + _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs, workflow_id) get_progress_state().finish_progress(unique_id) 
         execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)
@@ -482,7 +486,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
         if server.client_id is not None:
             server.last_node_id = display_node_id
-            server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
+            message = { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }
+            if workflow_id is not None:
+                message["workflow_id"] = workflow_id
+            server.send_sync("executing", message, server.client_id)
 
         obj = await caches.objects.get(unique_id)
         if obj is None:
@@ -522,6 +529,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                         "current_inputs": [],
                         "current_outputs": [],
                     }
+                    if workflow_id is not None:
+                        mes["workflow_id"] = workflow_id
                     server.send_sync("execution_error", mes, server.client_id)
                     return ExecutionBlocker(None)
                 else:
@@ -559,7 +568,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 "output": output_ui
             }
            if server.client_id is not None:
-                server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
+                message = { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }
+                if workflow_id is not None:
+                    message["workflow_id"] = workflow_id
+                server.send_sync("executed", message, server.client_id)
         if has_subgraph:
             cached_outputs = []
             new_node_ids = []
@@ -666,7 +678,7 @@ class PromptExecutor:
         if self.server.client_id is not None or broadcast:
             self.server.send_sync(event, data, self.server.client_id)
 
-    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
+    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex, workflow_id=None):
         node_id = error["node_id"]
         class_type = prompt[node_id]["class_type"]
 
@@ -679,6 +691,8 @@ class PromptExecutor:
                 "node_type": class_type,
                 "executed": list(executed),
             }
+            if workflow_id is not None:
+                mes["workflow_id"] = workflow_id
             self.add_message("execution_interrupted", mes, broadcast=True)
         else:
             mes = {
@@ -692,6 +706,8 @@ class PromptExecutor:
                 "current_inputs": error["current_inputs"],
                 "current_outputs": list(current_outputs),
             }
+            if workflow_id is not None:
+                mes["workflow_id"] = workflow_id
             self.add_message("execution_error", mes, broadcast=False)
 
     def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
@@ -720,8 +736,14 @@ class PromptExecutor:
         else:
             self.server.client_id = None
 
+        # Extract workflow_id from extra_data
+        workflow_id = extract_workflow_id(extra_data)
+
         self.status_messages = []
-        self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
+        execution_start_msg = { "prompt_id": prompt_id }
+        if workflow_id is not None:
+            execution_start_msg["workflow_id"] = workflow_id
+        self.add_message("execution_start", execution_start_msg, broadcast=False)
         self._notify_prompt_lifecycle("start", prompt_id)
 
         ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
@@ -731,7 +753,7 @@ class PromptExecutor:
         try:
             with torch.inference_mode():
                 dynamic_prompt = DynamicPrompt(prompt)
-                reset_progress_state(prompt_id, dynamic_prompt)
+                reset_progress_state(prompt_id, dynamic_prompt, workflow_id)
                 add_progress_handler(WebUIProgressHandler(self.server))
                 is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
                 for cache in self.caches.all:
@@ -748,9 +770,10 @@ class PromptExecutor:
                 ]
 
                 comfy.model_management.cleanup_models_gc()
-                self.add_message("execution_cached",
-                          { "nodes": cached_nodes, "prompt_id": prompt_id},
-                          broadcast=False)
+                execution_cached_msg = { "nodes": cached_nodes, "prompt_id": prompt_id }
+                if workflow_id is not None:
+                    execution_cached_msg["workflow_id"] = workflow_id
+                self.add_message("execution_cached", execution_cached_msg, broadcast=False)
                 pending_subgraph_results = {}
                 pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
                 ui_node_outputs = {}
@@ -763,14 +786,14 @@ class PromptExecutor:
                 while not execution_list.is_empty():
                     node_id, error, ex = await execution_list.stage_node_execution()
                     if error is not None:
-                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex, workflow_id)
                         break
 
                     assert node_id is not None, "Node ID should not be None at this point"
-                    result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
+                    result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs, workflow_id)
                     self.success = result != ExecutionResult.FAILURE
                     if result == ExecutionResult.FAILURE:
-                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex, workflow_id)
                         break
                     elif result == ExecutionResult.PENDING:
                         execution_list.unstage_node_execution()
@@ -791,8 +814,11 @@ class PromptExecutor:
                     cached = await self.caches.outputs.get(node_id)
                     if cached is not None:
                         display_node_id = dynamic_prompt.get_display_node_id(node_id)
-                        _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
-                self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+                        _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs, workflow_id)
+                execution_success_msg = { "prompt_id": prompt_id }
+                if workflow_id is not None:
+                    execution_success_msg["workflow_id"] = workflow_id
+                self.add_message("execution_success", execution_success_msg, broadcast=False)
 
                 ui_outputs = {}
                 meta_outputs = {}
diff --git a/main.py b/main.py
index dbaf2745c..53cef5e5e 100644
--- a/main.py
+++ b/main.py
@@ -21,6 +21,7 @@ import logging
 import sys
 from comfy_execution.progress import get_progress_state
 from comfy_execution.utils import get_executing_context
+from comfy_execution.jobs import extract_workflow_id
 
 from comfy_api import feature_flags
 from app.database.db import init_db, dependencies_available
@@ -322,7 +323,11 @@ def prompt_worker(q, server_instance):
                                 completed=e.success,
                                 messages=e.status_messages), process_item=remove_sensitive)
             if server_instance.client_id is not None:
-                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
+                workflow_id = extract_workflow_id(extra_data)
+                executing_msg = {"node": None, "prompt_id": prompt_id}
+                if workflow_id is not None:
+                    executing_msg["workflow_id"] = workflow_id
+                server_instance.send_sync("executing", executing_msg, server_instance.client_id)
 
             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
@@ -386,6 +391,12 @@ def hijack_progress(server_instance):
         if node_id is None:
             node_id = server_instance.last_node_id
         progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
+
+        # Add workflow_id if available from progress state
+        progress_state = get_progress_state()
+        if hasattr(progress_state, 'workflow_id') and progress_state.workflow_id is not None:
+            progress["workflow_id"] = progress_state.workflow_id
+
         get_progress_state().update_progress(node_id, value, total, preview_image)
         server_instance.send_sync("progress", progress, server_instance.client_id)
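
Reviewer note (illustrative sketch, not part of the patch): the extra_data shape the new helper expects and how the conditional workflow_id field behaves. The workflow id, prompt id, and sample payloads below are hypothetical.

    from typing import Optional

    def extract_workflow_id(extra_data: dict) -> Optional[str]:
        # Same logic as the helper added to comfy_execution/jobs.py above.
        if not extra_data:
            return None
        extra_pnginfo = extra_data.get('extra_pnginfo', {})
        return extra_pnginfo.get('workflow', {}).get('id')

    # Hypothetical extra_data from a frontend that embeds the workflow id.
    extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-123", "nodes": []}}}
    assert extract_workflow_id(extra_data) == "wf-123"

    # Workflows without an id: the helper returns None and messages omit the key.
    assert extract_workflow_id({}) is None

    # How execution_start is assembled in both cases (mirrors execution.py above).
    message = {"prompt_id": "p-1"}
    workflow_id = extract_workflow_id(extra_data)
    if workflow_id is not None:
        message["workflow_id"] = workflow_id
    # message == {"prompt_id": "p-1", "workflow_id": "wf-123"}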