mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-04 21:36:14 +02:00
Include workflow_id in all execution WebSocket messages
The frontend already stores extra_data['extra_pnginfo']['workflow']['id'] when queueing a prompt and exposes it via the /api/jobs REST endpoint, but none of the WebSocket events emitted during execution carry it. That makes it impossible to scope progress state by workflow on the client without maintaining a job_id -> workflow_id mapping that races with execution_start. This adds workflow_id alongside prompt_id on every execution event: - execution_start, execution_success, execution_error, execution_interrupted, execution_cached, executing, executed - progress and progress_state - the metadata block on PREVIEW_IMAGE_WITH_METADATA A new public extract_workflow_id helper in comfy_execution/jobs.py is the single source of truth for the lookup; the existing _extract_job_metadata delegates to it. The id is plumbed through PromptExecutor (stored as self.workflow_id and on server.last_workflow_id), the module-level execute() coroutine, the _send_cached_ui helper, and ProgressRegistry / reset_progress_state so WebUIProgressHandler can include it in progress_state and preview-image metadata. The progress hook in main.py reads server.last_workflow_id to populate the legacy 'progress' event. Tests cover the helper's edge cases (missing/non-string ids, non-dict levels) and that the WebUIProgressHandler emits workflow_id on every progress_state payload via mocked PromptServer.
This commit is contained in:
parent
e9c311b245
commit
2205341279
@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
|
||||
"""Extract the workflow id from a prompt's ``extra_data``.
|
||||
|
||||
The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]``
|
||||
when a prompt is queued. Any value that is not a non-empty string is treated as
|
||||
missing so callers can rely on the return being either ``None`` or a string.
|
||||
"""
|
||||
if not isinstance(extra_data, dict):
|
||||
return None
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo')
|
||||
if not isinstance(extra_pnginfo, dict):
|
||||
return None
|
||||
workflow = extra_pnginfo.get('workflow')
|
||||
if not isinstance(workflow, dict):
|
||||
return None
|
||||
workflow_id = workflow.get('id')
|
||||
if isinstance(workflow_id, str) and workflow_id:
|
||||
return workflow_id
|
||||
return None
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
"""Extract create_time and workflow_id from extra_data.
|
||||
|
||||
@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
|
||||
tuple: (create_time, workflow_id)
|
||||
"""
|
||||
create_time = extra_data.get('create_time')
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo', {})
|
||||
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
|
||||
workflow_id = extract_workflow_id(extra_data)
|
||||
return create_time, workflow_id
|
||||
|
||||
|
||||
|
||||
@ -164,6 +164,8 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
if self.server_instance is None:
|
||||
return
|
||||
|
||||
workflow_id = self.registry.workflow_id if self.registry else None
|
||||
|
||||
# Only send info for non-pending nodes
|
||||
active_nodes = {
|
||||
node_id: {
|
||||
@ -172,6 +174,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
"state": state["state"].value,
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": workflow_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||
@ -183,7 +186,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
# Send a combined progress_state message with all node states
|
||||
# Include client_id to ensure message is only sent to the initiating client
|
||||
self.server_instance.send_sync(
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||
"progress_state", {"prompt_id": prompt_id, "workflow_id": workflow_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||
)
|
||||
|
||||
@override
|
||||
@ -215,6 +218,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
metadata = {
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": self.registry.workflow_id if self.registry else None,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
||||
node_id
|
||||
),
|
||||
@ -240,9 +244,10 @@ class ProgressRegistry:
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
|
||||
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.workflow_id = workflow_id
|
||||
self.nodes: Dict[str, NodeProgressState] = {}
|
||||
self.handlers: Dict[str, ProgressHandler] = {}
|
||||
|
||||
@ -322,7 +327,7 @@ class ProgressRegistry:
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry | None = None
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None:
|
||||
global global_progress_registry
|
||||
|
||||
# Reset existing handlers if registry exists
|
||||
@ -330,7 +335,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
global_progress_registry.reset_handlers()
|
||||
|
||||
# Create new registry
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id)
|
||||
|
||||
|
||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||
|
||||
31
execution.py
31
execution.py
@ -37,6 +37,7 @@ from comfy_execution.graph import (
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.jobs import extract_workflow_id
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
@ -416,15 +417,15 @@ def _is_intermediate_output(dynprompt, node_id):
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
||||
|
||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs):
|
||||
if server.client_id is None:
|
||||
return
|
||||
cached_ui = cached.ui or {}
|
||||
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
|
||||
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||
if cached.ui is not None:
|
||||
ui_outputs[node_id] = cached.ui
|
||||
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
@ -434,7 +435,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
cached = await caches.outputs.get(unique_id)
|
||||
if cached is not None:
|
||||
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
|
||||
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs)
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
execution_list.cache_update(unique_id, cached)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
@ -482,7 +483,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||
|
||||
obj = await caches.objects.get(unique_id)
|
||||
if obj is None:
|
||||
@ -512,6 +513,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if block.message is not None:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": workflow_id,
|
||||
"node_id": unique_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
@ -559,7 +561,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
"output": output_ui
|
||||
}
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||
if has_subgraph:
|
||||
cached_outputs = []
|
||||
new_node_ids = []
|
||||
@ -656,6 +658,7 @@ class PromptExecutor:
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
self.workflow_id = None
|
||||
|
||||
def add_message(self, event, data: dict, broadcast: bool):
|
||||
data = {
|
||||
@ -675,6 +678,7 @@ class PromptExecutor:
|
||||
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
@ -683,6 +687,7 @@ class PromptExecutor:
|
||||
else:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
@ -721,7 +726,9 @@ class PromptExecutor:
|
||||
self.server.client_id = None
|
||||
|
||||
self.status_messages = []
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
self.workflow_id = extract_workflow_id(extra_data)
|
||||
self.server.last_workflow_id = self.workflow_id
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
|
||||
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
||||
@ -731,7 +738,7 @@ class PromptExecutor:
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt, self.workflow_id)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
@ -749,7 +756,7 @@ class PromptExecutor:
|
||||
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id, "workflow_id": self.workflow_id },
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
@ -767,7 +774,7 @@ class PromptExecutor:
|
||||
break
|
||||
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, self.workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
@ -791,8 +798,8 @@ class PromptExecutor:
|
||||
cached = await self.caches.outputs.get(node_id)
|
||||
if cached is not None:
|
||||
display_node_id = dynamic_prompt.get_display_node_id(node_id)
|
||||
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, self.workflow_id, ui_node_outputs)
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
|
||||
4
main.py
4
main.py
@ -322,7 +322,7 @@ def prompt_worker(q, server_instance):
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None)}, server_instance.client_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
@ -385,7 +385,7 @@ def hijack_progress(server_instance):
|
||||
prompt_id = server_instance.last_prompt_id
|
||||
if node_id is None:
|
||||
node_id = server_instance.last_node_id
|
||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None), "node": node_id}
|
||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||
|
||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||
|
||||
111
tests-unit/execution_test/test_workflow_id_in_ws_messages.py
Normal file
111
tests-unit/execution_test/test_workflow_id_in_ws_messages.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""Tests that workflow_id is included alongside prompt_id in WebSocket payloads
|
||||
emitted by the progress handler and the prompt executor.
|
||||
|
||||
Frontend stores extra_data["extra_pnginfo"]["workflow"]["id"] when queueing a
|
||||
prompt; we propagate that as `workflow_id` on every execution event so a
|
||||
multi-tab UI can scope progress state by workflow even when terminal
|
||||
WebSocket frames are dropped.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy_execution.progress import (
|
||||
NodeState,
|
||||
ProgressRegistry,
|
||||
WebUIProgressHandler,
|
||||
reset_progress_state,
|
||||
get_progress_state,
|
||||
)
|
||||
|
||||
|
||||
class _DummyDynPrompt:
|
||||
def get_display_node_id(self, node_id):
|
||||
return node_id
|
||||
|
||||
def get_parent_node_id(self, node_id):
|
||||
return None
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
return node_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
s = MagicMock()
|
||||
s.client_id = "client-1"
|
||||
return s
|
||||
|
||||
|
||||
def _registry(workflow_id):
|
||||
return ProgressRegistry(
|
||||
prompt_id="prompt-1",
|
||||
dynprompt=_DummyDynPrompt(),
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
|
||||
|
||||
class TestProgressStatePayload:
|
||||
def test_progress_state_includes_workflow_id(self, server):
|
||||
registry = _registry("wf-abc")
|
||||
registry.nodes["n1"] = {
|
||||
"state": NodeState.Running,
|
||||
"value": 1.0,
|
||||
"max": 5.0,
|
||||
}
|
||||
|
||||
handler = WebUIProgressHandler(server)
|
||||
handler.set_registry(registry)
|
||||
handler._send_progress_state("prompt-1", registry.nodes)
|
||||
|
||||
server.send_sync.assert_called_once()
|
||||
event, payload, sid = server.send_sync.call_args.args
|
||||
assert event == "progress_state"
|
||||
assert payload["prompt_id"] == "prompt-1"
|
||||
assert payload["workflow_id"] == "wf-abc"
|
||||
assert payload["nodes"]["n1"]["workflow_id"] == "wf-abc"
|
||||
assert payload["nodes"]["n1"]["prompt_id"] == "prompt-1"
|
||||
assert sid == "client-1"
|
||||
|
||||
def test_progress_state_workflow_id_none_when_missing(self, server):
|
||||
registry = _registry(None)
|
||||
registry.nodes["n1"] = {
|
||||
"state": NodeState.Running,
|
||||
"value": 0.5,
|
||||
"max": 1.0,
|
||||
}
|
||||
|
||||
handler = WebUIProgressHandler(server)
|
||||
handler.set_registry(registry)
|
||||
handler._send_progress_state("prompt-1", registry.nodes)
|
||||
|
||||
_, payload, _ = server.send_sync.call_args.args
|
||||
assert payload["workflow_id"] is None
|
||||
assert payload["nodes"]["n1"]["workflow_id"] is None
|
||||
|
||||
|
||||
class TestProgressRegistryConstruction:
|
||||
def test_workflow_id_default_is_none(self):
|
||||
registry = ProgressRegistry(
|
||||
prompt_id="prompt-1", dynprompt=_DummyDynPrompt()
|
||||
)
|
||||
assert registry.workflow_id is None
|
||||
|
||||
def test_workflow_id_stored_on_registry(self):
|
||||
registry = ProgressRegistry(
|
||||
prompt_id="prompt-1",
|
||||
dynprompt=_DummyDynPrompt(),
|
||||
workflow_id="wf-xyz",
|
||||
)
|
||||
assert registry.workflow_id == "wf-xyz"
|
||||
|
||||
|
||||
class TestResetProgressState:
|
||||
def test_reset_threads_workflow_id(self):
|
||||
reset_progress_state("prompt-1", _DummyDynPrompt(), "wf-456")
|
||||
assert get_progress_state().workflow_id == "wf-456"
|
||||
|
||||
def test_reset_default_workflow_id_none(self):
|
||||
reset_progress_state("prompt-2", _DummyDynPrompt())
|
||||
assert get_progress_state().workflow_id is None
|
||||
@ -10,9 +10,44 @@ from comfy_execution.jobs import (
|
||||
get_outputs_summary,
|
||||
apply_sorting,
|
||||
has_3d_extension,
|
||||
extract_workflow_id,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractWorkflowId:
|
||||
"""Unit tests for extract_workflow_id()."""
|
||||
|
||||
def test_returns_id_from_extra_pnginfo(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 'wf-123'}}}) == 'wf-123'
|
||||
|
||||
def test_missing_extra_data_returns_none(self):
|
||||
assert extract_workflow_id(None) is None
|
||||
|
||||
def test_non_dict_extra_data_returns_none(self):
|
||||
assert extract_workflow_id('not-a-dict') is None
|
||||
|
||||
def test_missing_extra_pnginfo_returns_none(self):
|
||||
assert extract_workflow_id({}) is None
|
||||
|
||||
def test_missing_workflow_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {}}) is None
|
||||
|
||||
def test_missing_id_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': {}}}) is None
|
||||
|
||||
def test_empty_string_id_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': ''}}}) is None
|
||||
|
||||
def test_non_string_id_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 42}}}) is None
|
||||
|
||||
def test_non_dict_workflow_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': 'not-a-dict'}}) is None
|
||||
|
||||
def test_non_dict_extra_pnginfo_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': 'not-a-dict'}) is None
|
||||
|
||||
|
||||
class TestJobStatus:
|
||||
"""Test JobStatus constants."""
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user