From 1f0b13705dc2cf875559907292461331ee3a8133 Mon Sep 17 00:00:00 2001 From: Glary-Bot Date: Sat, 2 May 2026 02:39:30 +0000 Subject: [PATCH] Address review: clear stale workflow_id, expand reconnect payload, harden tests - Clear self.server.last_workflow_id and self.workflow_id in the PromptExecutor finally block so a progress callback racing with teardown can no longer attach the previous run's workflow_id to a later 'progress' event. - Include prompt_id and last_workflow_id in the reconnect 'executing' message in server.py so reconnecting clients can recover both workflow- and prompt-scoped execution state, matching the regular 'executing' payload. - Add an AST-based static guard that walks execution.py, main.py, and comfy_execution/progress.py and asserts every dict literal carrying prompt_id also carries workflow_id. Also add a unit test covering PREVIEW_IMAGE_WITH_METADATA metadata. Together these regression-test every emitter (execution_start/success/error/interrupted/cached, executing, executed, progress, progress_state, preview metadata) without requiring a GPU-backed import of execution.py. --- execution.py | 2 + server.py | 6 +- .../test_workflow_id_in_ws_messages.py | 110 ++++++++++++++++++ 3 files changed, 117 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 4f6b5f2c5..5014da312 100644 --- a/execution.py +++ b/execution.py @@ -816,6 +816,8 @@ class PromptExecutor: finally: comfy.memory_management.set_ram_cache_release_state(None, 0) self._notify_prompt_lifecycle("end", prompt_id) + self.server.last_workflow_id = None + self.workflow_id = None async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): diff --git a/server.py b/server.py index 881da8e66..5029fbb27 100644 --- a/server.py +++ b/server.py @@ -274,7 +274,11 @@ class PromptServer(): await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: - await self.send("executing", { "node": self.last_node_id }, sid) + await self.send("executing", { + "node": self.last_node_id, + "prompt_id": getattr(self, "last_prompt_id", None), + "workflow_id": getattr(self, "last_workflow_id", None), + }, sid) # Flag to track if we've received the first message first_message = True diff --git a/tests-unit/execution_test/test_workflow_id_in_ws_messages.py b/tests-unit/execution_test/test_workflow_id_in_ws_messages.py index 0645c02c0..ff675a9a7 100644 --- a/tests-unit/execution_test/test_workflow_id_in_ws_messages.py +++ b/tests-unit/execution_test/test_workflow_id_in_ws_messages.py @@ -109,3 +109,113 @@ class TestResetProgressState: def test_reset_default_workflow_id_none(self): reset_progress_state("prompt-2", _DummyDynPrompt()) assert get_progress_state().workflow_id is None + + +class TestExecutionMessagePayloadsContainWorkflowId: + """Static-analysis guard ensuring every WebSocket message payload that + carries `prompt_id` also carries `workflow_id`. This is a regression net + for future refactors of execution.py / main.py / progress.py and avoids + the GPU/torch dependency of importing `execution.py` directly. + """ + + @staticmethod + def _emitting_dicts(source: str): + """Yield every dict literal in `source` that contains a 'prompt_id' key.""" + import ast + + tree = ast.parse(source) + for node in ast.walk(tree): + if not isinstance(node, ast.Dict): + continue + keys = [ + k.value + for k in node.keys + if isinstance(k, ast.Constant) and isinstance(k.value, str) + ] + if "prompt_id" in keys: + yield node, keys + + def _assert_workflow_id_in_every_prompt_id_dict(self, file_path: str): + from pathlib import Path + + source = Path(file_path).read_text() + offenders = [] + for node, keys in self._emitting_dicts(source): + if "workflow_id" not in keys: + offenders.append((node.lineno, keys)) + assert not offenders, ( + f"{file_path}: dict literals with 'prompt_id' but no 'workflow_id': {offenders}" + ) + + def test_execution_py_payloads_include_workflow_id(self): + self._assert_workflow_id_in_every_prompt_id_dict("execution.py") + + def test_main_py_payloads_include_workflow_id(self): + self._assert_workflow_id_in_every_prompt_id_dict("main.py") + + def test_progress_py_payloads_include_workflow_id(self): + self._assert_workflow_id_in_every_prompt_id_dict("comfy_execution/progress.py") + + +class TestPreviewImageMetadataPayload: + """Verify PREVIEW_IMAGE_WITH_METADATA metadata carries workflow_id.""" + + def test_preview_metadata_includes_workflow_id(self): + from unittest.mock import MagicMock, patch + from PIL import Image + + from comfy_execution.progress import ( + NodeState, + ProgressRegistry, + WebUIProgressHandler, + ) + + class _DynPrompt: + def get_display_node_id(self, n): + return n + + def get_parent_node_id(self, n): + return None + + def get_real_node_id(self, n): + return n + + server = MagicMock() + server.client_id = "cid" + server.sockets_metadata = {} + + registry = ProgressRegistry( + prompt_id="p1", dynprompt=_DynPrompt(), workflow_id="wf-1" + ) + handler = WebUIProgressHandler(server) + handler.set_registry(registry) + + image = ("PNG", Image.new("RGB", (1, 1)), None) + + with patch( + "comfy_execution.progress.feature_flags.supports_feature", + return_value=True, + ): + handler.update_handler( + node_id="n1", + value=1.0, + max_value=1.0, + state={ + "state": NodeState.Running, + "value": 1.0, + "max": 1.0, + }, + prompt_id="p1", + image=image, + ) + + preview_calls = [ + c + for c in server.send_sync.call_args_list + if c.args[0] != "progress_state" + ] + assert len(preview_calls) == 1 + _, payload, _ = preview_calls[0].args + _, metadata = payload + assert metadata["prompt_id"] == "p1" + assert metadata["workflow_id"] == "wf-1"