Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,13 @@ async def _run_core(
yield event

elif checkpoint_id is not None:
# Restore the prior workflow state from the checkpoint. Shared
# state (e.g. accumulated conversation history maintained by the
# workflow's executors) survives across turns because Workflow.run
# no longer wipes state per call. Callers who want to deliver a
# new user message after restore should make a second
# `workflow.run(message=...)` call - they are NOT mutually
# exclusive on the same instance, but each must be its own call.
if streaming:
async for event in self.workflow.run(
stream=True,
Expand Down
7 changes: 6 additions & 1 deletion python/packages/core/agent_framework/_workflows/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,12 @@ async def restore_from_checkpoint(
"Please rebuild the original workflow before resuming."
)

# Restore state
# Restore state. Clear first so import_state (which merges) does
# not leak stale keys from a prior run on this Workflow instance.
# This matters more now that Workflow.run() no longer wipes state
# per call - the only reset point for shared state on a reused
# instance is at restore time.
self._state.clear()
self._state.import_state(checkpoint.state)
# Restore executor states using the restored state
await self._restore_executor_states()
Expand Down
85 changes: 58 additions & 27 deletions python/packages/core/agent_framework/_workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def get_executors_list(self) -> list[Executor]:
async def _run_workflow_with_tracing(
self,
initial_executor_fn: Callable[[], Awaitable[None]] | None = None,
reset_context: bool = True,
is_continuation: bool = False,
streaming: bool = False,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
Expand All @@ -310,13 +310,19 @@ async def _run_workflow_with_tracing(
of external callers to maintain context across different workflow runs.

Args:
initial_executor_fn: Optional function to execute initial executor
reset_context: Whether to reset the context for a new run
streaming: Whether to enable streaming mode for agents
initial_executor_fn: Optional function to execute initial executor.
is_continuation: True when this run is a continuation of prior
work (a checkpoint restore or a responses-only replay) rather
than a fresh new turn delivered via the start executor with
``message=...``. Continuations preserve per-run accounting
(iteration counter and run kwargs) from the prior turn;
fresh-message runs reset them. Shared workflow state is
preserved in both cases.
streaming: Whether to enable streaming mode for agents.
function_invocation_kwargs: Optional kwargs to store in State for function
invocations in subagents
invocations in subagents.
client_kwargs: Optional kwargs to store in State for chat client
invocations in subagents
invocations in subagents.

Yields:
WorkflowEvent: The events generated during the workflow execution.
Expand Down Expand Up @@ -345,16 +351,26 @@ async def _run_workflow_with_tracing(
in_progress = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS)
yield in_progress # noqa: RUF070

# Reset context for a new run if supported
if reset_context:
# Per-run reset for fresh-message runs only. We deliberately
# do NOT clear shared workflow state (`_state.clear()`) or the
# runner context's in-flight messages (`reset_for_new_run()`)
# here - state and pending work persist across `run()` calls
# so that a `WorkflowAgent` can deliver multi-turn input on
# the same instance and have prior turns' context survive.
# Iteration counting and per-run kwargs ARE per-run though,
# so they're reset here.
if not is_continuation:
self._runner.reset_iteration_count()
self._runner.context.reset_for_new_run()
self._state.clear()

# Store run kwargs in State so executors can access them.
# Only overwrite when new kwargs are explicitly provided or state was
# just cleared (fresh run). On continuation (reset_context=False) with
# no new kwargs, preserve the kwargs from the original run.
# Per-run kwargs semantics:
# - On a fresh message run, prior kwargs go away (set to {}
# by default, or to the new kwargs if provided). This
# prevents stale kwargs from a prior turn leaking into the
# current turn.
# - On a continuation (checkpoint restore or responses), the
# prior run's kwargs are preserved unless the caller
# explicitly provides new kwargs.
if function_invocation_kwargs is not None or client_kwargs is not None:
combined_kwargs: dict[str, Any] = {}
if function_invocation_kwargs is not None:
Expand All @@ -366,11 +382,12 @@ async def _run_workflow_with_tracing(
client_kwargs, "client_kwargs"
)
self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs)
elif reset_context:
elif not is_continuation:
self._state.set(WORKFLOW_RUN_KWARGS_KEY, {})
self._state.commit() # Commit immediately so kwargs are available

# Set streaming mode after reset
# Set streaming mode (always set explicitly per run since
# reset_for_new_run() no longer runs to clear it).
self._runner_context.set_streaming(streaming)

# Execute initial setup if provided
Expand Down Expand Up @@ -585,13 +602,33 @@ async def _run_core(
if checkpoint_storage is not None:
self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage)

initial_executor_fn, reset_context = self._resolve_execution_mode(
# Async validation: a fresh-message run is only allowed when the
# runner context has fully drained from any prior run. If it still
# has in-flight executor messages, the prior run didn't complete -
# the caller must either resume from a checkpoint or wait for the
# prior run to drain. (Pending request_info events are intentionally
# NOT blocked here: a follow-up run with message=... is the normal
# way to deliver a response to those pending requests, e.g. via
# WorkflowAgent._process_pending_requests.)
# NOTE: _validate_run_params already enforces that ``message`` is
# mutually exclusive with both ``checkpoint_id`` and ``responses``,
# so we don't need to re-check those here.
if message is not None and await self._runner.context.has_messages():
raise RuntimeError(
"Cannot start a new run with 'message' while in-flight executor "
"messages remain from a prior run. Resume from a checkpoint "
"(checkpoint_id=...) or wait for the prior run to complete. "
"Workflows that need to recover from a mid-run failure must use "
"checkpointing; there is no in-process recovery path."
)

initial_executor_fn = self._resolve_execution_mode(
message, responses, checkpoint_id, checkpoint_storage
)

async for event in self._run_workflow_with_tracing(
initial_executor_fn=initial_executor_fn,
reset_context=reset_context,
is_continuation=(message is None),
streaming=streaming,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
Expand Down Expand Up @@ -674,12 +711,8 @@ def _resolve_execution_mode(
responses: Mapping[str, Any] | None,
checkpoint_id: str | None,
checkpoint_storage: CheckpointStorage | None,
) -> tuple[Callable[[], Awaitable[None]], bool]:
"""Determine the initial executor function and reset_context flag based on parameters.

Returns:
A tuple of (initial_executor_fn, reset_context).
"""
) -> Callable[[], Awaitable[None]]:
"""Determine the initial executor function based on parameters."""
if responses is not None:
if checkpoint_id is not None:
# Combined: restore checkpoint then send responses
Expand All @@ -689,13 +722,11 @@ def _resolve_execution_mode(
else:
# Send responses only (requires pending requests in workflow state)
initial_executor_fn = functools.partial(self._send_responses_internal, responses)
return initial_executor_fn, False
return initial_executor_fn
# Regular run or checkpoint restoration
initial_executor_fn = functools.partial(
return functools.partial(
self._execute_with_message_or_checkpoint, message, checkpoint_id, checkpoint_storage
)
reset_context = message is not None and checkpoint_id is None
return initial_executor_fn, reset_context

async def _restore_and_send_responses(
self,
Expand Down
80 changes: 64 additions & 16 deletions python/packages/core/tests/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,13 @@ async def handle_message(
await ctx.yield_output(existing_messages.copy()) # type: ignore


async def test_workflow_multiple_runs_no_state_collision():
"""Test that running the same workflow instance multiple times doesn't have state collision."""
async def test_workflow_multiple_runs_preserve_state():
"""Test that running the same workflow instance multiple times preserves shared state.

State preservation is the new default - calling ``Workflow.run`` repeatedly
on the same instance behaves like a chat agent maintaining memory across
turns. Callers that want fresh state should rebuild the Workflow.
"""
with tempfile.TemporaryDirectory() as temp_dir:
storage = FileCheckpointStorage(temp_dir)

Expand All @@ -503,29 +508,45 @@ async def test_workflow_multiple_runs_no_state_collision():
.build()
)

# Run 1: Should only see messages from run 1
# Run 1: Single record from run 1
result1 = await workflow.run(StateTrackingMessage(data="message1", run_id="run1"))
assert result1.get_final_state() == WorkflowRunState.IDLE
outputs1 = result1.get_outputs()
assert outputs1[0] == ["run1:message1"]

# Run 2: Should only see messages from run 2, not run 1
# Run 2: State from run 1 persists; run 2's record appends.
result2 = await workflow.run(StateTrackingMessage(data="message2", run_id="run2"))
assert result2.get_final_state() == WorkflowRunState.IDLE
outputs2 = result2.get_outputs()
assert outputs2[0] == ["run2:message2"] # Should NOT contain run1 data
assert outputs2[0] == ["run1:message1", "run2:message2"]

# Run 3: Should only see messages from run 3
# Run 3: Same - all three accumulate.
result3 = await workflow.run(StateTrackingMessage(data="message3", run_id="run3"))
assert result3.get_final_state() == WorkflowRunState.IDLE
outputs3 = result3.get_outputs()
assert outputs3[0] == ["run3:message3"] # Should NOT contain run1 or run2 data
assert outputs3[0] == ["run1:message1", "run2:message2", "run3:message3"]


async def test_workflow_multiple_runs_no_state_collision_after_rebuild():
"""Rebuilding the Workflow gives a fresh shared-state slate."""
with tempfile.TemporaryDirectory() as temp_dir:
storage = FileCheckpointStorage(temp_dir)

def _build():
executor = StateTrackingExecutor(id="state_executor")
return (
WorkflowBuilder(start_executor=executor, checkpoint_storage=storage)
.add_edge(executor, executor)
.build()
)

# Verify that each run only processed its own message
# This confirms that the checkpointable context properly resets between runs
assert outputs1[0] != outputs2[0]
assert outputs2[0] != outputs3[0]
assert outputs1[0] != outputs3[0]
wf1 = _build()
result1 = await wf1.run(StateTrackingMessage(data="message1", run_id="run1"))
assert result1.get_outputs()[0] == ["run1:message1"]

wf2 = _build()
result2 = await wf2.run(StateTrackingMessage(data="message2", run_id="run2"))
assert result2.get_outputs()[0] == ["run2:message2"]


async def test_workflow_checkpoint_runtime_only_configuration(
Expand Down Expand Up @@ -932,6 +953,31 @@ async def test_agent_streaming_vs_non_streaming() -> None:
assert accumulated_text == "Hello World", f"Expected 'Hello World', got '{accumulated_text}'"


async def test_workflow_run_inflight_messages_guard(simple_executor: Executor) -> None:
"""``run(message=...)`` must reject in-flight executor messages from a prior run.

Workflows preserve state and pending messages across :meth:`Workflow.run`
calls. If a prior run aborted before the runner drained those pending
messages (e.g. it raised :class:`WorkflowConvergenceException`), the next
fresh-message call should fail loudly instead of silently mixing the
leftover messages with the new turn. The supported recovery path is to
resume from a checkpoint; there is no in-process recovery hatch.
"""
workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build()
test_message = WorkflowMessage(data="test", source_id="test", target_id=None)

# Simulate an aborted prior run by leaving a message in the runner context.
workflow._runner.context._messages["test"] = [test_message]
assert await workflow._runner.context.has_messages()

with pytest.raises(RuntimeError, match="in-flight executor messages"):
await workflow.run(test_message)

with pytest.raises(RuntimeError, match="in-flight executor messages"):
async for _ in workflow.run(test_message, stream=True):
pass


async def test_workflow_run_parameter_validation(simple_executor: Executor) -> None:
"""Test that stream properly validate parameter combinations."""
workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build()
Expand All @@ -942,13 +988,15 @@ async def test_workflow_run_parameter_validation(simple_executor: Executor) -> N
result = await workflow.run(test_message)
assert result.get_final_state() == WorkflowRunState.IDLE

# Invalid: both message and checkpoint_id
# Invalid: message + checkpoint_id (mutually exclusive). Multi-turn
# state preservation is handled by Workflow.run preserving state across
# calls, so the host pattern is two separate calls (restore-then-run),
# not a single combined call.
with pytest.raises(ValueError, match="Cannot provide both 'message' and 'checkpoint_id'"):
await workflow.run(test_message, checkpoint_id="fake_id")
await workflow.run(test_message, checkpoint_id="some-checkpoint")

# Invalid: both message and checkpoint_id (streaming)
with pytest.raises(ValueError, match="Cannot provide both 'message' and 'checkpoint_id'"):
async for _ in workflow.run(test_message, checkpoint_id="fake_id", stream=True):
async for _ in workflow.run(test_message, checkpoint_id="some-checkpoint", stream=True):
pass

# Invalid: none of message or checkpoint_id
Expand Down
Loading
Loading