mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-04 21:36:14 +02:00
Address review: clear stale workflow_id, expand reconnect payload, harden tests
- Clear self.server.last_workflow_id and self.workflow_id in the PromptExecutor finally block so a progress callback racing with teardown can no longer attach the previous run's workflow_id to a later 'progress' event. - Include prompt_id and last_workflow_id in the reconnect 'executing' message in server.py so reconnecting clients can recover both workflow- and prompt-scoped execution state, matching the regular 'executing' payload. - Add an AST-based static guard that walks execution.py, main.py, and comfy_execution/progress.py and asserts every dict literal carrying prompt_id also carries workflow_id. Also add a unit test covering PREVIEW_IMAGE_WITH_METADATA metadata. Together these regression-test every emitter (execution_start/success/error/interrupted/cached, executing, executed, progress, progress_state, preview metadata) without requiring a GPU-backed import of execution.py.
This commit is contained in:
parent
2205341279
commit
1f0b13705d
@ -816,6 +816,8 @@ class PromptExecutor:
|
||||
finally:
|
||||
comfy.memory_management.set_ram_cache_release_state(None, 0)
|
||||
self._notify_prompt_lifecycle("end", prompt_id)
|
||||
self.server.last_workflow_id = None
|
||||
self.workflow_id = None
|
||||
|
||||
|
||||
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
|
||||
@ -274,7 +274,11 @@ class PromptServer():
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
await self.send("executing", {
|
||||
"node": self.last_node_id,
|
||||
"prompt_id": getattr(self, "last_prompt_id", None),
|
||||
"workflow_id": getattr(self, "last_workflow_id", None),
|
||||
}, sid)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
|
||||
@ -109,3 +109,113 @@ class TestResetProgressState:
|
||||
def test_reset_default_workflow_id_none(self):
    """Resetting progress state without a workflow_id leaves it unset."""
    dummy_prompt = _DummyDynPrompt()
    reset_progress_state("prompt-2", dummy_prompt)
    state = get_progress_state()
    assert state.workflow_id is None
|
||||
|
||||
|
||||
class TestExecutionMessagePayloadsContainWorkflowId:
|
||||
"""Static-analysis guard ensuring every WebSocket message payload that
|
||||
carries `prompt_id` also carries `workflow_id`. This is a regression net
|
||||
for future refactors of execution.py / main.py / progress.py and avoids
|
||||
the GPU/torch dependency of importing `execution.py` directly.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _emitting_dicts(source: str):
|
||||
"""Yield every dict literal in `source` that contains a 'prompt_id' key."""
|
||||
import ast
|
||||
|
||||
tree = ast.parse(source)
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.Dict):
|
||||
continue
|
||||
keys = [
|
||||
k.value
|
||||
for k in node.keys
|
||||
if isinstance(k, ast.Constant) and isinstance(k.value, str)
|
||||
]
|
||||
if "prompt_id" in keys:
|
||||
yield node, keys
|
||||
|
||||
def _assert_workflow_id_in_every_prompt_id_dict(self, file_path: str):
|
||||
from pathlib import Path
|
||||
|
||||
source = Path(file_path).read_text()
|
||||
offenders = []
|
||||
for node, keys in self._emitting_dicts(source):
|
||||
if "workflow_id" not in keys:
|
||||
offenders.append((node.lineno, keys))
|
||||
assert not offenders, (
|
||||
f"{file_path}: dict literals with 'prompt_id' but no 'workflow_id': {offenders}"
|
||||
)
|
||||
|
||||
def test_execution_py_payloads_include_workflow_id(self):
|
||||
self._assert_workflow_id_in_every_prompt_id_dict("execution.py")
|
||||
|
||||
def test_main_py_payloads_include_workflow_id(self):
|
||||
self._assert_workflow_id_in_every_prompt_id_dict("main.py")
|
||||
|
||||
def test_progress_py_payloads_include_workflow_id(self):
|
||||
self._assert_workflow_id_in_every_prompt_id_dict("comfy_execution/progress.py")
|
||||
|
||||
|
||||
class TestPreviewImageMetadataPayload:
    """Verify PREVIEW_IMAGE_WITH_METADATA metadata carries workflow_id."""

    def test_preview_metadata_includes_workflow_id(self):
        from unittest.mock import MagicMock, patch
        from PIL import Image

        from comfy_execution.progress import (
            NodeState,
            ProgressRegistry,
            WebUIProgressHandler,
        )

        class _StubDynPrompt:
            # Identity stub: every node is its own display/real node, no parent.
            def get_display_node_id(self, node_id):
                return node_id

            def get_parent_node_id(self, node_id):
                return None

            def get_real_node_id(self, node_id):
                return node_id

        mock_server = MagicMock()
        mock_server.client_id = "cid"
        mock_server.sockets_metadata = {}

        handler = WebUIProgressHandler(mock_server)
        handler.set_registry(
            ProgressRegistry(
                prompt_id="p1", dynprompt=_StubDynPrompt(), workflow_id="wf-1"
            )
        )

        preview_image = ("PNG", Image.new("RGB", (1, 1)), None)

        # Feature flag gate must report support, otherwise the handler falls
        # back to the legacy (metadata-less) preview message.
        feature_patch = patch(
            "comfy_execution.progress.feature_flags.supports_feature",
            return_value=True,
        )
        with feature_patch:
            handler.update_handler(
                node_id="n1",
                value=1.0,
                max_value=1.0,
                state={"state": NodeState.Running, "value": 1.0, "max": 1.0},
                prompt_id="p1",
                image=preview_image,
            )

        # Keep only the preview message; progress_state events are emitted too.
        non_progress = []
        for call in mock_server.send_sync.call_args_list:
            if call.args[0] != "progress_state":
                non_progress.append(call)
        assert len(non_progress) == 1

        _event, payload, _sid = non_progress[0].args
        _image_part, metadata = payload
        assert metadata["prompt_id"] == "p1"
        assert metadata["workflow_id"] == "wf-1"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user