Address review: clear stale workflow_id, expand reconnect payload, harden tests

- Clear self.server.last_workflow_id and self.workflow_id in the
  PromptExecutor finally block so a progress callback racing with
  teardown can no longer attach the previous run's workflow_id to a
  later 'progress' event.
- Include prompt_id and last_workflow_id in the reconnect 'executing'
  message in server.py so reconnecting clients can recover both
  workflow- and prompt-scoped execution state, matching the regular
  'executing' payload.
- Add an AST-based static guard that walks execution.py, main.py, and
  comfy_execution/progress.py and asserts every dict literal carrying
  prompt_id also carries workflow_id. Also add a unit test covering
  PREVIEW_IMAGE_WITH_METADATA metadata. Together these regression-test
  every emitter (execution_start/success/error/interrupted/cached,
  executing, executed, progress, progress_state, preview metadata)
  without requiring a GPU-backed import of execution.py.
This commit is contained in:
Glary-Bot 2026-05-02 02:39:30 +00:00
parent 2205341279
commit 1f0b13705d
3 changed files with 117 additions and 1 deletions

View File

@ -816,6 +816,8 @@ class PromptExecutor:
finally:
comfy.memory_management.set_ram_cache_release_state(None, 0)
self._notify_prompt_lifecycle("end", prompt_id)
self.server.last_workflow_id = None
self.workflow_id = None
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):

View File

@ -274,7 +274,11 @@ class PromptServer():
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid)
await self.send("executing", {
"node": self.last_node_id,
"prompt_id": getattr(self, "last_prompt_id", None),
"workflow_id": getattr(self, "last_workflow_id", None),
}, sid)
# Flag to track if we've received the first message
first_message = True

View File

@ -109,3 +109,113 @@ class TestResetProgressState:
def test_reset_default_workflow_id_none(self):
reset_progress_state("prompt-2", _DummyDynPrompt())
assert get_progress_state().workflow_id is None
class TestExecutionMessagePayloadsContainWorkflowId:
"""Static-analysis guard ensuring every WebSocket message payload that
carries `prompt_id` also carries `workflow_id`. This is a regression net
for future refactors of execution.py / main.py / progress.py and avoids
the GPU/torch dependency of importing `execution.py` directly.
"""
@staticmethod
def _emitting_dicts(source: str):
"""Yield every dict literal in `source` that contains a 'prompt_id' key."""
import ast
tree = ast.parse(source)
for node in ast.walk(tree):
if not isinstance(node, ast.Dict):
continue
keys = [
k.value
for k in node.keys
if isinstance(k, ast.Constant) and isinstance(k.value, str)
]
if "prompt_id" in keys:
yield node, keys
def _assert_workflow_id_in_every_prompt_id_dict(self, file_path: str):
from pathlib import Path
source = Path(file_path).read_text()
offenders = []
for node, keys in self._emitting_dicts(source):
if "workflow_id" not in keys:
offenders.append((node.lineno, keys))
assert not offenders, (
f"{file_path}: dict literals with 'prompt_id' but no 'workflow_id': {offenders}"
)
def test_execution_py_payloads_include_workflow_id(self):
self._assert_workflow_id_in_every_prompt_id_dict("execution.py")
def test_main_py_payloads_include_workflow_id(self):
self._assert_workflow_id_in_every_prompt_id_dict("main.py")
def test_progress_py_payloads_include_workflow_id(self):
self._assert_workflow_id_in_every_prompt_id_dict("comfy_execution/progress.py")
class TestPreviewImageMetadataPayload:
"""Verify PREVIEW_IMAGE_WITH_METADATA metadata carries workflow_id."""
def test_preview_metadata_includes_workflow_id(self):
from unittest.mock import MagicMock, patch
from PIL import Image
from comfy_execution.progress import (
NodeState,
ProgressRegistry,
WebUIProgressHandler,
)
class _DynPrompt:
def get_display_node_id(self, n):
return n
def get_parent_node_id(self, n):
return None
def get_real_node_id(self, n):
return n
server = MagicMock()
server.client_id = "cid"
server.sockets_metadata = {}
registry = ProgressRegistry(
prompt_id="p1", dynprompt=_DynPrompt(), workflow_id="wf-1"
)
handler = WebUIProgressHandler(server)
handler.set_registry(registry)
image = ("PNG", Image.new("RGB", (1, 1)), None)
with patch(
"comfy_execution.progress.feature_flags.supports_feature",
return_value=True,
):
handler.update_handler(
node_id="n1",
value=1.0,
max_value=1.0,
state={
"state": NodeState.Running,
"value": 1.0,
"max": 1.0,
},
prompt_id="p1",
image=image,
)
preview_calls = [
c
for c in server.send_sync.call_args_list
if c.args[0] != "progress_state"
]
assert len(preview_calls) == 1
_, payload, _ = preview_calls[0].args
_, metadata = payload
assert metadata["prompt_id"] == "p1"
assert metadata["workflow_id"] == "wf-1"