Include workflow_id in all execution WebSocket messages

The frontend already stores extra_data['extra_pnginfo']['workflow']['id']
when queueing a prompt and exposes it via the /api/jobs REST endpoint, but
none of the WebSocket events emitted during execution carry it. That makes
it impossible to scope progress state by workflow on the client without
maintaining a job_id -> workflow_id mapping that races with execution_start.

This adds workflow_id alongside prompt_id on every execution event:

- execution_start, execution_success, execution_error, execution_interrupted,
  execution_cached, executing, executed
- progress and progress_state
- the metadata block on PREVIEW_IMAGE_WITH_METADATA

A new public extract_workflow_id helper in comfy_execution/jobs.py is the
single source of truth for the lookup; the existing _extract_job_metadata
delegates to it. The id is plumbed through PromptExecutor (stored as
self.workflow_id and on server.last_workflow_id), the module-level
execute() coroutine, the _send_cached_ui helper, and ProgressRegistry /
reset_progress_state so WebUIProgressHandler can include it in
progress_state and preview-image metadata. The progress hook in main.py
reads server.last_workflow_id to populate the legacy 'progress' event.

Tests cover the helper's edge cases (missing/non-string ids, non-dict
levels) and that the WebUIProgressHandler emits workflow_id on every
progress_state payload via mocked PromptServer.
This commit is contained in:
Glary-Bot 2026-05-02 02:29:55 +00:00
parent e9c311b245
commit 2205341279
6 changed files with 198 additions and 20 deletions

View File

@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict:
}
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
"""Extract the workflow id from a prompt's ``extra_data``.
The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]``
when a prompt is queued. Any value that is not a non-empty string is treated as
missing so callers can rely on the return being either ``None`` or a string.
"""
if not isinstance(extra_data, dict):
return None
extra_pnginfo = extra_data.get('extra_pnginfo')
if not isinstance(extra_pnginfo, dict):
return None
workflow = extra_pnginfo.get('workflow')
if not isinstance(workflow, dict):
return None
workflow_id = workflow.get('id')
if isinstance(workflow_id, str) and workflow_id:
return workflow_id
return None
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data.
@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
tuple: (create_time, workflow_id)
"""
create_time = extra_data.get('create_time')
extra_pnginfo = extra_data.get('extra_pnginfo', {})
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
workflow_id = extract_workflow_id(extra_data)
return create_time, workflow_id

View File

@ -164,6 +164,8 @@ class WebUIProgressHandler(ProgressHandler):
if self.server_instance is None:
return
workflow_id = self.registry.workflow_id if self.registry else None
# Only send info for non-pending nodes
active_nodes = {
node_id: {
@ -172,6 +174,7 @@ class WebUIProgressHandler(ProgressHandler):
"state": state["state"].value,
"node_id": node_id,
"prompt_id": prompt_id,
"workflow_id": workflow_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
@ -183,7 +186,7 @@ class WebUIProgressHandler(ProgressHandler):
# Send a combined progress_state message with all node states
# Include client_id to ensure message is only sent to the initiating client
self.server_instance.send_sync(
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
"progress_state", {"prompt_id": prompt_id, "workflow_id": workflow_id, "nodes": active_nodes}, self.server_instance.client_id
)
@override
@ -215,6 +218,7 @@ class WebUIProgressHandler(ProgressHandler):
metadata = {
"node_id": node_id,
"prompt_id": prompt_id,
"workflow_id": self.registry.workflow_id if self.registry else None,
"display_node_id": self.registry.dynprompt.get_display_node_id(
node_id
),
@ -240,9 +244,10 @@ class ProgressRegistry:
Registry that maintains node progress state and notifies registered handlers.
"""
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None):
self.prompt_id = prompt_id
self.dynprompt = dynprompt
self.workflow_id = workflow_id
self.nodes: Dict[str, NodeProgressState] = {}
self.handlers: Dict[str, ProgressHandler] = {}
@ -322,7 +327,7 @@ class ProgressRegistry:
# Global registry instance
global_progress_registry: ProgressRegistry | None = None
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None:
global global_progress_registry
# Reset existing handlers if registry exists
@ -330,7 +335,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
global_progress_registry.reset_handlers()
# Create new registry
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id)
def add_progress_handler(handler: ProgressHandler) -> None:

View File

@ -37,6 +37,7 @@ from comfy_execution.graph import (
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.jobs import extract_workflow_id
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
@ -416,15 +417,15 @@ def _is_intermediate_output(dynprompt, node_id):
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs):
if server.client_id is None:
return
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
if cached.ui is not None:
ui_outputs[node_id] = cached.ui
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@ -434,7 +435,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = await caches.outputs.get(unique_id)
if cached is not None:
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs)
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)
@ -482,7 +483,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
obj = await caches.objects.get(unique_id)
if obj is None:
@ -512,6 +513,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if block.message is not None:
mes = {
"prompt_id": prompt_id,
"workflow_id": workflow_id,
"node_id": unique_id,
"node_type": class_type,
"executed": list(executed),
@ -559,7 +561,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
"output": output_ui
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
@ -656,6 +658,7 @@ class PromptExecutor:
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = []
self.success = True
self.workflow_id = None
def add_message(self, event, data: dict, broadcast: bool):
data = {
@ -675,6 +678,7 @@ class PromptExecutor:
if isinstance(ex, comfy.model_management.InterruptProcessingException):
mes = {
"prompt_id": prompt_id,
"workflow_id": self.workflow_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
@ -683,6 +687,7 @@ class PromptExecutor:
else:
mes = {
"prompt_id": prompt_id,
"workflow_id": self.workflow_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
@ -721,7 +726,9 @@ class PromptExecutor:
self.server.client_id = None
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self.workflow_id = extract_workflow_id(extra_data)
self.server.last_workflow_id = self.workflow_id
self.add_message("execution_start", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
@ -731,7 +738,7 @@ class PromptExecutor:
try:
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
reset_progress_state(prompt_id, dynamic_prompt, self.workflow_id)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
@ -749,7 +756,7 @@ class PromptExecutor:
comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
{ "nodes": cached_nodes, "prompt_id": prompt_id, "workflow_id": self.workflow_id },
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
@ -767,7 +774,7 @@ class PromptExecutor:
break
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, self.workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@ -791,8 +798,8 @@ class PromptExecutor:
cached = await self.caches.outputs.get(node_id)
if cached is not None:
display_node_id = dynamic_prompt.get_display_node_id(node_id)
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, self.workflow_id, ui_node_outputs)
self.add_message("execution_success", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}

View File

@ -322,7 +322,7 @@ def prompt_worker(q, server_instance):
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None)}, server_instance.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
@ -385,7 +385,7 @@ def hijack_progress(server_instance):
prompt_id = server_instance.last_prompt_id
if node_id is None:
node_id = server_instance.last_node_id
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
progress = {"value": value, "max": total, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None), "node": node_id}
get_progress_state().update_progress(node_id, value, total, preview_image)
server_instance.send_sync("progress", progress, server_instance.client_id)

View File

@ -0,0 +1,111 @@
"""Tests that workflow_id is included alongside prompt_id in WebSocket payloads
emitted by the progress handler and the prompt executor.
Frontend stores extra_data["extra_pnginfo"]["workflow"]["id"] when queueing a
prompt; we propagate that as `workflow_id` on every execution event so a
multi-tab UI can scope progress state by workflow even when terminal
WebSocket frames are dropped.
"""
from unittest.mock import MagicMock
import pytest
from comfy_execution.progress import (
NodeState,
ProgressRegistry,
WebUIProgressHandler,
reset_progress_state,
get_progress_state,
)
class _DummyDynPrompt:
def get_display_node_id(self, node_id):
return node_id
def get_parent_node_id(self, node_id):
return None
def get_real_node_id(self, node_id):
return node_id
@pytest.fixture
def server():
s = MagicMock()
s.client_id = "client-1"
return s
def _registry(workflow_id):
return ProgressRegistry(
prompt_id="prompt-1",
dynprompt=_DummyDynPrompt(),
workflow_id=workflow_id,
)
class TestProgressStatePayload:
def test_progress_state_includes_workflow_id(self, server):
registry = _registry("wf-abc")
registry.nodes["n1"] = {
"state": NodeState.Running,
"value": 1.0,
"max": 5.0,
}
handler = WebUIProgressHandler(server)
handler.set_registry(registry)
handler._send_progress_state("prompt-1", registry.nodes)
server.send_sync.assert_called_once()
event, payload, sid = server.send_sync.call_args.args
assert event == "progress_state"
assert payload["prompt_id"] == "prompt-1"
assert payload["workflow_id"] == "wf-abc"
assert payload["nodes"]["n1"]["workflow_id"] == "wf-abc"
assert payload["nodes"]["n1"]["prompt_id"] == "prompt-1"
assert sid == "client-1"
def test_progress_state_workflow_id_none_when_missing(self, server):
registry = _registry(None)
registry.nodes["n1"] = {
"state": NodeState.Running,
"value": 0.5,
"max": 1.0,
}
handler = WebUIProgressHandler(server)
handler.set_registry(registry)
handler._send_progress_state("prompt-1", registry.nodes)
_, payload, _ = server.send_sync.call_args.args
assert payload["workflow_id"] is None
assert payload["nodes"]["n1"]["workflow_id"] is None
class TestProgressRegistryConstruction:
def test_workflow_id_default_is_none(self):
registry = ProgressRegistry(
prompt_id="prompt-1", dynprompt=_DummyDynPrompt()
)
assert registry.workflow_id is None
def test_workflow_id_stored_on_registry(self):
registry = ProgressRegistry(
prompt_id="prompt-1",
dynprompt=_DummyDynPrompt(),
workflow_id="wf-xyz",
)
assert registry.workflow_id == "wf-xyz"
class TestResetProgressState:
def test_reset_threads_workflow_id(self):
reset_progress_state("prompt-1", _DummyDynPrompt(), "wf-456")
assert get_progress_state().workflow_id == "wf-456"
def test_reset_default_workflow_id_none(self):
reset_progress_state("prompt-2", _DummyDynPrompt())
assert get_progress_state().workflow_id is None

View File

@ -10,9 +10,44 @@ from comfy_execution.jobs import (
get_outputs_summary,
apply_sorting,
has_3d_extension,
extract_workflow_id,
)
class TestExtractWorkflowId:
"""Unit tests for extract_workflow_id()."""
def test_returns_id_from_extra_pnginfo(self):
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 'wf-123'}}}) == 'wf-123'
def test_missing_extra_data_returns_none(self):
assert extract_workflow_id(None) is None
def test_non_dict_extra_data_returns_none(self):
assert extract_workflow_id('not-a-dict') is None
def test_missing_extra_pnginfo_returns_none(self):
assert extract_workflow_id({}) is None
def test_missing_workflow_returns_none(self):
assert extract_workflow_id({'extra_pnginfo': {}}) is None
def test_missing_id_returns_none(self):
assert extract_workflow_id({'extra_pnginfo': {'workflow': {}}}) is None
def test_empty_string_id_returns_none(self):
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': ''}}}) is None
def test_non_string_id_returns_none(self):
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 42}}}) is None
def test_non_dict_workflow_returns_none(self):
assert extract_workflow_id({'extra_pnginfo': {'workflow': 'not-a-dict'}}) is None
def test_non_dict_extra_pnginfo_returns_none(self):
assert extract_workflow_id({'extra_pnginfo': 'not-a-dict'}) is None
class TestJobStatus:
"""Test JobStatus constants."""