diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 05f65873bc..364f62eae1 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -213,6 +213,15 @@ handler, ) from ._workflows._function_executor import FunctionExecutor, executor +from ._workflows._functional import ( + FunctionalWorkflow, + FunctionalWorkflowAgent, + RunContext, + StepWrapper, + get_run_context, + step, + workflow, +) from ._workflows._request_info_mixin import response_handler from ._workflows._runner import Runner from ._workflows._runner_context import ( @@ -332,6 +341,8 @@ "FunctionMiddleware", "FunctionMiddlewareTypes", "FunctionTool", + "FunctionalWorkflow", + "FunctionalWorkflowAgent", "GeneratedEmbeddings", "GraphConnectivityError", "HistoryProvider", @@ -354,6 +365,7 @@ "ResponseStream", "Role", "RoleLiteral", + "RunContext", "Runner", "RunnerContext", "SecretString", @@ -366,6 +378,7 @@ "SkillScriptRunner", "SkillsProvider", "SlidingWindowStrategy", + "StepWrapper", "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", "SummarizationStrategy", @@ -424,6 +437,7 @@ "evaluator", "executor", "function_middleware", + "get_run_context", "handler", "included_messages", "included_token_count", @@ -439,6 +453,7 @@ "register_state_type", "resolve_agent_id", "response_handler", + "step", "tool", "tool_call_args_match", "tool_called_check", @@ -447,4 +462,5 @@ "validate_tool_mode", "validate_tools", "validate_workflow_graph", + "workflow", ] diff --git a/python/packages/core/agent_framework/_feature_stage.py b/python/packages/core/agent_framework/_feature_stage.py index 761b7860a4..ef7dfd3687 100644 --- a/python/packages/core/agent_framework/_feature_stage.py +++ b/python/packages/core/agent_framework/_feature_stage.py @@ -48,6 +48,7 @@ class ExperimentalFeature(str, Enum): EVALS = "EVALS" FILE_HISTORY = "FILE_HISTORY" + FUNCTIONAL_WORKFLOWS = "FUNCTIONAL_WORKFLOWS" 
SKILLS = "SKILLS" TOOLBOXES = "TOOLBOXES" diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index d26952d8e5..4b8238268c 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -120,6 +120,7 @@ def from_exception( "executor_invoked", # Executor handler was called (use .executor_id, .data) "executor_completed", # Executor handler completed (use .executor_id, .data) "executor_failed", # Executor handler raised error (use .executor_id, .details) + "executor_bypassed", # Executor skipped via cache hit during replay (use .executor_id, .data) # Orchestration event types (use .data for typed payload) "group_chat", # Group chat orchestrator events (use .data as GroupChatRequestSentEvent | GroupChatResponseReceivedEvent) # noqa: E501 "handoff_sent", # Handoff routing events (use .data as HandoffSentEvent) @@ -148,6 +149,7 @@ class WorkflowEvent(Generic[DataT]): - `WorkflowEvent.executor_invoked(executor_id)` - executor handler called - `WorkflowEvent.executor_completed(executor_id)` - executor handler completed - `WorkflowEvent.executor_failed(executor_id, details)` - executor handler failed + - `WorkflowEvent.executor_bypassed(executor_id)` - executor skipped via cache hit The generic parameter DataT represents the type of the event's data payload: - Lifecycle events: `WorkflowEvent[None]` (data is None) @@ -318,6 +320,11 @@ def executor_failed(cls, executor_id: str, details: WorkflowErrorDetails) -> Wor """Create an 'executor_failed' event when an executor handler raises an error.""" return WorkflowEvent("executor_failed", executor_id=executor_id, data=details, details=details) + @classmethod + def executor_bypassed(cls, executor_id: str, data: DataT | None = None) -> WorkflowEvent[DataT]: + """Create an 'executor_bypassed' event when a step is skipped via cache hit during replay.""" + return 
cls("executor_bypassed", executor_id=executor_id, data=data) + # ========================================================================== # Property for type-safe access # ========================================================================== diff --git a/python/packages/core/agent_framework/_workflows/_functional.py b/python/packages/core/agent_framework/_workflows/_functional.py new file mode 100644 index 0000000000..159d75e137 --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_functional.py @@ -0,0 +1,1551 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Functional workflow API for writing workflows as plain async functions. + +.. warning:: Experimental + + This API is experimental and subject to change or removal + in future versions without notice. + +This module provides the ``@workflow`` and ``@step`` decorators that let users +define workflows using native Python control flow (if/else, loops, +``asyncio.gather``) instead of a graph-based topology. + +A ``@workflow``-decorated async function receives its input as the first +positional argument. If the function needs HITL (``request_info``), custom +events, or key/value state, add a :class:`RunContext` parameter — otherwise it +can be omitted. Inside the workflow, plain ``async`` calls run normally. +Optionally, ``@step``-decorated functions gain caching, per-step checkpointing, +and event emission. ``@step`` functions may also declare a ``RunContext`` +parameter to access HITL and state APIs directly. + +Key public symbols: + +* :func:`workflow` / :class:`FunctionalWorkflow` — decorator and runtime. +* :func:`step` / :class:`StepWrapper` — optional step decorator. +* :class:`RunContext` — execution context injected into workflow and step + functions. +* :func:`get_run_context` — retrieve the active ``RunContext`` from anywhere + inside a running workflow. +* :class:`FunctionalWorkflowAgent` — agent adapter returned by + :meth:`FunctionalWorkflow.as_agent`. 
+""" + +from __future__ import annotations + +# pyright: reportPrivateUsage=false +# Classes in this module (RunContext, StepWrapper, FunctionalWorkflow) form a +# cohesive unit and intentionally access each other's underscore-prefixed members. +import functools +import hashlib +import inspect +import logging +import typing +from collections.abc import AsyncIterable, Awaitable, Callable, Sequence +from contextvars import ContextVar +from copy import deepcopy +from typing import Any, Generic, Literal, TypeVar, overload + +from .._feature_stage import ExperimentalFeature, experimental +from .._types import AgentResponse, AgentResponseUpdate, ResponseStream +from ..observability import OtelAttr, capture_exception, create_workflow_span +from ._checkpoint import CheckpointStorage, WorkflowCheckpoint +from ._events import ( + WorkflowErrorDetails, + WorkflowEvent, + WorkflowRunState, + _framework_event_origin, # type: ignore[reportPrivateUsage] +) +from ._workflow import WorkflowRunResult + +logger = logging.getLogger(__name__) + +R = TypeVar("R") + +# ContextVar holding the active RunContext during workflow execution. +# ContextVar is per-asyncio-Task, so concurrent workflows each get their own context. +_active_run_ctx: ContextVar[RunContext | None] = ContextVar("_active_run_ctx", default=None) + + +@experimental(feature_id=ExperimentalFeature.FUNCTIONAL_WORKFLOWS) +def get_run_context() -> RunContext | None: + """Return the active :class:`RunContext`, or ``None`` if not inside a ``@workflow``. + + This is useful inside ``@step`` functions (or any code called from a + workflow) that need access to HITL, state, or event APIs without + requiring a ``RunContext`` parameter. 
+ """ + return _active_run_ctx.get() + + +# --------------------------------------------------------------------------- +# Internal exception for HITL interruption +# --------------------------------------------------------------------------- + + +class WorkflowInterrupted(BaseException): + """Internal: raised when request_info() is called during initial execution. + + Inherits from ``BaseException`` (not ``Exception``) so that user code + with ``except Exception:`` handlers inside a ``@workflow`` function does + not accidentally intercept the HITL interruption signal. + """ + + def __init__(self, request_id: str, request_data: Any, response_type: type) -> None: + self.request_id = request_id + self.request_data = request_data + self.response_type = response_type + super().__init__(f"Workflow interrupted by request_info (request_id={request_id})") + + +# --------------------------------------------------------------------------- +# RunContext +# --------------------------------------------------------------------------- + + +@experimental(feature_id=ExperimentalFeature.FUNCTIONAL_WORKFLOWS) +class RunContext: + """Opt-in handle for workflow-only features inside a ``@workflow`` function. + + Use ``RunContext`` when a workflow function needs one of the following, + otherwise omit it entirely for a cleaner signature: + + * Human-in-the-loop: :meth:`request_info` pauses the workflow until a + response is supplied, then resumes with that value. + * Custom events: :meth:`add_event` emits events into the run stream + (useful for progress reporting or tracing). + * Workflow-scoped key/value state: :meth:`get_state` / :meth:`set_state` + persist values across a run and survive checkpoints. + + The context is injected automatically. Declare it either by parameter + name (``ctx``) or by type annotation (``: RunContext``); both work. + + Args: + workflow_name: Identifier for the enclosing workflow, used when + generating events and checkpoint metadata. 
+ streaming: Whether the current run was started with ``stream=True``. + run_kwargs: Extra keyword arguments forwarded from + :meth:`FunctionalWorkflow.run`. + + Examples: + + .. code-block:: python + + # Simple workflow: no context parameter needed. + @workflow + async def my_pipeline(data: str) -> str: + return await some_step(data) + + + # HITL workflow: request a response from a human reviewer. + @workflow + async def hitl_pipeline(data: str, ctx: RunContext) -> str: + feedback = await ctx.request_info({"draft": data}, response_type=str) + return feedback + + + # RunContext also works inside @step functions. + @step + async def review_step(doc: str, ctx: RunContext) -> str: + feedback = await ctx.request_info({"draft": doc}, response_type=str) + return feedback + """ + + def __init__( + self, + workflow_name: str, + *, + streaming: bool = False, + run_kwargs: dict[str, Any] | None = None, + ) -> None: + self._workflow_name = workflow_name + self._streaming = streaming + self._run_kwargs = run_kwargs or {} + + # Event accumulator + self._events: list[WorkflowEvent[Any]] = [] + + # Step result cache: (step_name, call_index) -> result + self._step_cache: dict[tuple[str, int], Any] = {} + # Cached step metadata used to keep auto-generated request_info IDs in sync on bypass. 
+ self._step_cache_auto_request_info_counts: dict[tuple[str, int], int] = {} + # Per-step call counters for deterministic cache keys + self._step_call_counters: dict[str, int] = {} + # Deterministic call counter for auto-generated request_info IDs + self._auto_request_info_index: int = 0 + + # HITL responses (set via _set_responses before replay) + self._responses: dict[str, Any] = {} + # Pending request_info events (for checkpointing) + self._pending_requests: dict[str, WorkflowEvent[Any]] = {} + + # User state (simple dict) + self._state: dict[str, Any] = {} + + # Callback invoked after each step completes (set by FunctionalWorkflow) + self._on_step_completed: Callable[[], Awaitable[None]] | None = None + + # ------------------------------------------------------------------ + # Public API (for @workflow functions) + # ------------------------------------------------------------------ + + async def request_info( + self, + request_data: Any, + response_type: type, + *, + request_id: str | None = None, + ) -> Any: + """Request external information (human-in-the-loop). + + On first execution this suspends the workflow by raising an internal + ``WorkflowInterrupted`` signal (caught by the framework, never exposed + to user code). The caller receives a ``WorkflowRunResult`` (or a + ``ResponseStream`` when ``stream=True``) whose + :meth:`~WorkflowRunResult.get_request_info_events` contains the pending + request. When the workflow is resumed with + ``run(responses={request_id: value})``, the same function re-executes + and ``request_info`` returns the provided *value* directly. + + Args: + request_data: Arbitrary payload describing what information is + needed (e.g. a Pydantic model, dict, or string prompt). + response_type: The expected Python type of the response value. + request_id: Optional stable identifier for this request. 
If + omitted, a deterministic identifier is derived from the call + order (``auto::<index>``) so that resume works without the + caller needing to echo back an explicit ID. + + Returns: + The response value supplied during replay. ``None`` is allowed + but triggers a warning — prefer a sentinel value when the + absence of data is meaningful. + + Raises: + WorkflowInterrupted: Raised internally on initial execution + (not visible to workflow authors). + """ + if request_id is None: + # Deterministic id; same determinism contract as @step caching. + rid = f"auto::{self._auto_request_info_index}" + self._auto_request_info_index += 1 + else: + rid = request_id + + found, value = self._get_response(rid) + if found: + self._pending_requests.pop(rid, None) + return value + + # No response — emit event and interrupt + event = WorkflowEvent.request_info( + request_id=rid, + source_executor_id=self._workflow_name, + request_data=request_data, + response_type=response_type, + ) + await self.add_event(event) + self._pending_requests[rid] = event + raise WorkflowInterrupted(rid, request_data, response_type) + + async def add_event(self, event: WorkflowEvent[Any]) -> None: + """Add a custom event to the workflow event stream. + + Use this to inject application-specific events alongside the + framework-generated lifecycle events. + + Args: + event: The workflow event to append. + """ + self._events.append(event) + + def get_state(self, key: str, default: Any = None) -> Any: + """Retrieve a value from the workflow's key/value state. + + State values are persisted across HITL interruptions and are included + in checkpoints when checkpoint storage is configured. + + Args: + key: The state key to look up. + default: Value returned when *key* is absent. + + Returns: + The stored value, or *default* if the key does not exist. + """ + return self._state.get(key, default) + + def set_state(self, key: str, value: Any) -> None: + """Store a value in the workflow's key/value state. 
+ + Args: + key: The state key. Must not start with ``_`` — framework + bookkeeping (e.g. ``_step_cache``, ``_original_message``) uses + the underscore prefix and user keys in that namespace are + silently clobbered by checkpoint save and dropped on + checkpoint restore. Use names without a leading underscore + for user state. + value: The value to store. Must be JSON-serializable if + checkpoint storage is used. + + Raises: + ValueError: If *key* begins with ``_`` (reserved for framework + bookkeeping). + """ + if key.startswith("_"): + raise ValueError( + f"State key {key!r} starts with '_', which is reserved for " + f"framework bookkeeping (e.g. '_step_cache', '_original_message') " + f"and would be silently dropped on checkpoint restore. Use a " + f"non-underscore-prefixed key for user state." + ) + self._state[key] = value + + def is_streaming(self) -> bool: + """Return whether the current run was started with ``stream=True``. + + Returns: + ``True`` if the workflow is running in streaming mode. 
+ """ + return self._streaming + + # ------------------------------------------------------------------ + # Internal API (for StepWrapper and FunctionalWorkflow) + # ------------------------------------------------------------------ + + def _get_events(self) -> list[WorkflowEvent[Any]]: + return list(self._events) + + def _get_step_cache_key(self, step_name: str) -> tuple[str, int]: + idx = self._step_call_counters.get(step_name, 0) + self._step_call_counters[step_name] = idx + 1 + return (step_name, idx) + + def _get_cached_result(self, key: tuple[str, int]) -> tuple[bool, Any]: + if key in self._step_cache: + return True, self._step_cache[key] + return False, None + + def _set_cached_result(self, key: tuple[str, int], value: Any) -> None: + self._step_cache[key] = value + + def _set_cached_step_auto_request_info_count(self, key: tuple[str, int], count: int) -> None: + self._step_cache_auto_request_info_counts[key] = count + + def _advance_auto_request_info_index_for_cached_step(self, key: tuple[str, int]) -> None: + self._auto_request_info_index += self._step_cache_auto_request_info_counts.get(key, 0) + + def _set_responses(self, responses: dict[str, Any]) -> None: + for rid, value in responses.items(): + if value is None: + logger.warning( + "Response for request_id=%r is None. If this is intentional, " + "consider using a sentinel value instead.", + rid, + ) + self._responses = dict(responses) + # Remove resolved requests from the pending set so downstream + # checkpoints don't re-serialize them as still-pending. + for rid in responses: + self._pending_requests.pop(rid, None) + + def _get_response(self, request_id: str) -> tuple[bool, Any]: + """Look up a HITL response by *request_id*. + + Returns: + A ``(found, value)`` tuple. When *found* is ``True``, *value* is + the caller-supplied response (which **may be** ``None`` — a warning + is logged by :meth:`_set_responses` in that case). 
When *found* is + ``False``, *value* is always ``None`` and simply means no response + has been provided yet. + """ + if request_id in self._responses: + return True, self._responses[request_id] + return False, None + + def _export_step_cache(self) -> dict[str, Any]: + """Serialize the step cache for checkpointing. + + Converts tuple keys to strings for JSON compatibility. + """ + return {f"{name}::{idx}": val for (name, idx), val in self._step_cache.items()} + + def _export_step_cache_auto_request_info_counts(self) -> dict[str, int]: + """Serialize per-step auto request_info counts for checkpointing.""" + return {f"{name}::{idx}": count for (name, idx), count in self._step_cache_auto_request_info_counts.items()} + + def _import_step_cache(self, data: dict[str, Any]) -> None: + """Restore step cache from checkpoint data.""" + self._step_cache = {} + for k, v in data.items(): + try: + name, idx_str = k.rsplit("::", 1) + self._step_cache[name, int(idx_str)] = v + except (ValueError, TypeError) as exc: + raise ValueError( + f"Corrupted step cache entry in checkpoint: key={k!r}. " + f"The checkpoint may be from an incompatible version or corrupted. " + f"Original error: {exc}" + ) from exc + + def _import_step_cache_auto_request_info_counts(self, data: dict[str, Any]) -> None: + """Restore per-step auto request_info counts from checkpoint data.""" + self._step_cache_auto_request_info_counts = {} + for k, v in data.items(): + try: + name, idx_str = k.rsplit("::", 1) + self._step_cache_auto_request_info_counts[name, int(idx_str)] = int(v) + except (ValueError, TypeError) as exc: + raise ValueError( + f"Corrupted step cache request_info metadata in checkpoint: key={k!r}, value={v!r}. " + f"The checkpoint may be from an incompatible version or corrupted. 
" + f"Original error: {exc}" + ) from exc + + +# --------------------------------------------------------------------------- +# StepWrapper +# --------------------------------------------------------------------------- + + +@experimental(feature_id=ExperimentalFeature.FUNCTIONAL_WORKFLOWS) +class StepWrapper(Generic[R]): + """Wrapper returned by the ``@step`` decorator. + + When called inside a running ``@workflow`` function, the wrapper + intercepts execution to provide: + + * **Caching** — results are cached by ``(step_name, call_index)`` so + that HITL replay and checkpoint restore skip already-completed work. + On cache hit a single ``executor_bypassed`` event is emitted instead + of the normal ``executor_invoked`` / ``executor_completed`` pair. + * **Event emission** — ``executor_invoked`` / ``executor_completed`` / + ``executor_failed`` events are emitted for observability. + * **RunContext injection** — if the step function declares a parameter + annotated as :class:`RunContext` (or named ``ctx``), the active + context is automatically injected, giving step functions access to + HITL, state, and event APIs. + * **Per-step checkpointing** — a checkpoint is saved after each live + execution when checkpoint storage is configured. + + Outside a workflow the wrapper is transparent: it delegates directly to + the original function, making decorated functions fully testable in + isolation. + + Args: + func: The async function to wrap. + name: Optional display name. Defaults to ``func.__name__``. + + Raises: + TypeError: If *func* is not an async (coroutine) function. + """ + + def __init__(self, func: Callable[..., Awaitable[R]], *, name: str | None = None) -> None: + if not inspect.iscoroutinefunction(func): + raise TypeError( + f"@step can only decorate async functions, but '{func.__name__}' is not a coroutine function." 
+ ) + self._func = func + self.name: str = name or func.__name__ + self._signature = inspect.signature(func) + functools.update_wrapper(self, func) + + # Detect RunContext parameter for auto-injection inside workflows + self._ctx_param_name: str | None = None + try: + hints = typing.get_type_hints(func) + except Exception: + hints = {} + for param_name, param in self._signature.parameters.items(): + if param.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ): + continue + resolved = hints.get(param_name, param.annotation) + if resolved is RunContext or param_name == "ctx": + self._ctx_param_name = param_name + break + + def _build_call_args_with_ctx( + self, + ctx: RunContext, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + """Inject RunContext without consuming a user positional argument.""" + if self._ctx_param_name is None or self._ctx_param_name in kwargs: + return args, dict(kwargs) + + call_args: list[Any] = [] + call_kwargs = dict(kwargs) + arg_index = 0 + + for param in self._signature.parameters.values(): + if param.name == self._ctx_param_name: + if param.kind == inspect.Parameter.KEYWORD_ONLY: + call_kwargs[param.name] = ctx + else: + call_args.append(ctx) + continue + + if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): + if arg_index < len(args): + call_args.append(args[arg_index]) + arg_index += 1 + elif param.kind == inspect.Parameter.VAR_POSITIONAL: + call_args.extend(args[arg_index:]) + arg_index = len(args) + + if arg_index < len(args): + call_args.extend(args[arg_index:]) + + return tuple(call_args), call_kwargs + + async def __call__(self, *args: Any, **kwargs: Any) -> R: + ctx = _active_run_ctx.get() + if ctx is None: + # Outside a workflow — pass through directly + return await self._func(*args, **kwargs) + + cache_key = ctx._get_step_cache_key(self.name) + found, 
cached = ctx._get_cached_result(cache_key) + if found: + ctx._advance_auto_request_info_index_for_cached_step(cache_key) + # Dedicated bypass event so consumers can tell cache-hit replays + # apart from fresh executions. + await ctx.add_event(WorkflowEvent.executor_bypassed(self.name, cached)) + return cached # type: ignore[return-value, no-any-return] + + # Inject RunContext if the step function declares it + call_args, call_kwargs = self._build_call_args_with_ctx(ctx, args, kwargs) + + # Defensive deepcopy for the event log only; fall back to the live + # reference so non-deepcopyable args (locks, sockets) don't fail. + if args or kwargs: + try: + invocation_data: Any = deepcopy({"args": args, "kwargs": kwargs}) + except Exception: + invocation_data = {"args": args, "kwargs": kwargs} + else: + invocation_data = None + await ctx.add_event(WorkflowEvent.executor_invoked(self.name, invocation_data)) + auto_request_info_index_before = ctx._auto_request_info_index + try: + result = await self._func(*call_args, **call_kwargs) + except Exception as exc: + # NOTE: WorkflowInterrupted (from request_info inside a step) inherits + # from BaseException, NOT Exception, so it propagates past this handler + # without emitting a spurious executor_failed event. This is intentional + # — request_info is fully supported inside @step functions. 
+ await ctx.add_event(WorkflowEvent.executor_failed(self.name, WorkflowErrorDetails.from_exception(exc))) + raise + ctx._set_cached_step_auto_request_info_count( + cache_key, + ctx._auto_request_info_index - auto_request_info_index_before, + ) + ctx._set_cached_result(cache_key, result) + await ctx.add_event(WorkflowEvent.executor_completed(self.name, result)) + if ctx._on_step_completed is not None: + await ctx._on_step_completed() + return result + + +# --------------------------------------------------------------------------- +# @step decorator +# --------------------------------------------------------------------------- + + +@overload +def step(func: Callable[..., Awaitable[R]]) -> StepWrapper[R]: ... + + +@overload +def step(*, name: str | None = None) -> Callable[[Callable[..., Awaitable[R]]], StepWrapper[R]]: ... + + +@experimental(feature_id=ExperimentalFeature.FUNCTIONAL_WORKFLOWS) +def step( + func: Callable[..., Awaitable[Any]] | None = None, + *, + name: str | None = None, +) -> StepWrapper[Any] | Callable[[Callable[..., Awaitable[Any]]], StepWrapper[Any]]: + """Decorator that marks an async function as a tracked workflow step. + + Supports both bare ``@step`` and parameterized ``@step(name="custom")`` + forms. Inside a running ``@workflow`` function, calls to a step are + intercepted for result caching, event emission, and per-step + checkpointing. If the step function declares a :class:`RunContext` + parameter (by type annotation or the name ``ctx``), the active context + is automatically injected, giving the step access to + :meth:`~RunContext.request_info`, state, and event APIs. Outside a + workflow the decorated function behaves identically to the original, + making it fully testable in isolation. + + The ``@step`` decorator is **optional**. Plain async functions work + inside ``@workflow`` without it; use ``@step`` only when you need + caching, checkpointing, or observability for a particular call. 
+ + Args: + func: The async function to decorate (when using the bare + ``@step`` form). + name: Optional display name for the step. Defaults to the + function's ``__name__``. + + Returns: + A :class:`StepWrapper` (bare form) or a decorator that produces + one (parameterized form). + + Raises: + TypeError: If the decorated function is not async. + + Examples: + + .. code-block:: python + + @step + async def fetch_data(url: str) -> dict: + return await http_get(url) + + + @step(name="transform") + async def transform_data(raw: dict) -> str: + return json.dumps(raw) + + + # Step with HITL — RunContext is auto-injected inside a workflow: + @step + async def review(doc: str, ctx: RunContext) -> str: + return await ctx.request_info({"draft": doc}, response_type=str) + """ + if func is not None: + return StepWrapper(func, name=name) + + def _decorator(fn: Callable[..., Awaitable[Any]]) -> StepWrapper[Any]: + return StepWrapper(fn, name=name) + + return _decorator + + +# --------------------------------------------------------------------------- +# FunctionalWorkflow +# --------------------------------------------------------------------------- + + +@experimental(feature_id=ExperimentalFeature.FUNCTIONAL_WORKFLOWS) +class FunctionalWorkflow: + """A workflow backed by a user-defined async function. + + Created by the :func:`workflow` decorator. Exposes the same ``run()`` + interface as graph-based :class:`Workflow` objects, returning a + :class:`WorkflowRunResult` (or a :class:`ResponseStream` in streaming + mode). + + The underlying function is executed directly — no graph compilation or + edge wiring is involved. Native Python control flow (``if``/``else``, + ``for``, ``asyncio.gather``) is used for branching and parallelism. + + Args: + func: The async function that implements the workflow logic. + name: Display name for the workflow. Defaults to ``func.__name__``. + description: Optional human-readable description. 
+ checkpoint_storage: Default :class:`CheckpointStorage` used for + persisting step results and state between runs. Can be + overridden per-run via the *checkpoint_storage* parameter of + :meth:`run`. + + Examples: + + .. code-block:: python + + @workflow + async def my_pipeline(data: str) -> str: + return await to_upper(data) + + + result = await my_pipeline.run("hello") + print(result.get_outputs()) # ['HELLO'] + """ + + def __init__( + self, + func: Callable[..., Awaitable[Any]], + *, + name: str | None = None, + description: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + ) -> None: + self._func = func + self.name = name or func.__name__ + self.description = description + self._checkpoint_storage = checkpoint_storage + self._is_running = False + # Replay state: cleared on clean completion so later responses-only + # calls can't silently replay with stale data from a prior run. + self._last_message: Any = None + self._last_step_cache: dict[tuple[str, int], Any] = {} + self._last_step_cache_auto_request_info_counts: dict[tuple[str, int], int] = {} + self._last_pending_request_ids: set[str] = set() + + # Signature arity is validated once at decoration time. + self._non_ctx_param_names = self._classify_signature(func) + + # Discover step names referenced in the function for signature hash + self._step_names = self._discover_step_names(func) + + # Compute a stable signature hash + self.graph_signature_hash = self._compute_signature_hash() + + functools.update_wrapper(self, func) # type: ignore[arg-type] + + @staticmethod + def _classify_signature(func: Callable[..., Any]) -> list[str]: + """Return the names of non-ctx parameters, validating arity. + + A workflow function may declare at most one non-ctx parameter (which + receives the caller-supplied ``message``). Any extra non-ctx + parameters would be silently dropped by ``_execute``, so we reject + them at decoration time. 
+ """ + try: + hints = typing.get_type_hints(func) + except Exception: + hints = {} + non_ctx: list[str] = [] + for param_name, param in inspect.signature(func).parameters.items(): + resolved = hints.get(param_name, param.annotation) + if resolved is RunContext or param_name == "ctx": + continue + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + continue + non_ctx.append(param_name) + if len(non_ctx) > 1: + raise ValueError( + f"@workflow function '{func.__name__}' declares multiple non-RunContext " + f"parameters ({non_ctx}); at most one is supported (it receives the " + f"'message' argument passed to .run()). Combine the inputs into a " + f"single object or dict." + ) + return non_ctx + + # ------------------------------------------------------------------ + # run() — same overloaded interface as graph Workflow + # ------------------------------------------------------------------ + + @overload + def run( + self, + message: Any | None = None, + *, + stream: Literal[True], + responses: dict[str, Any] | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + **kwargs: Any, + ) -> ResponseStream[WorkflowEvent[Any], WorkflowRunResult]: ... + + @overload + def run( + self, + message: Any | None = None, + *, + stream: Literal[False] = ..., + responses: dict[str, Any] | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + include_status_events: bool = False, + **kwargs: Any, + ) -> Awaitable[WorkflowRunResult]: ... + + def run( + self, + message: Any | None = None, + *, + stream: bool = False, + responses: dict[str, Any] | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + include_status_events: bool = False, + **kwargs: Any, + ) -> ResponseStream[WorkflowEvent[Any], WorkflowRunResult] | Awaitable[WorkflowRunResult]: + """Run the functional workflow. 
+ + At least one of *message*, *responses*, or *checkpoint_id* must be + provided. *message* starts a fresh run; *responses* resumes after a + HITL interruption; *checkpoint_id* restores from a previously saved + checkpoint. *responses* may be combined with *checkpoint_id* to + restore a checkpoint and inject HITL responses in a single call. + *message* is mutually exclusive with both *responses* and + *checkpoint_id*. + + Args: + message: Input data passed as the first positional argument to + the workflow function. + stream: If ``True``, return a :class:`ResponseStream` that + yields :class:`WorkflowEvent` instances as they are produced. + responses: HITL responses keyed by ``request_id``, used to + resume a workflow that was suspended by + :meth:`RunContext.request_info`. + checkpoint_id: Identifier of a checkpoint to restore from. + Requires *checkpoint_storage* to be set (here or on the + decorator). + checkpoint_storage: Override the default checkpoint storage + for this run. + include_status_events: When ``True`` (non-streaming only), + include status-change events in the result. + + Keyword Args: + **kwargs: Extra keyword arguments stored on + :attr:`RunContext._run_kwargs` and accessible to step + functions. + + Returns: + A :class:`WorkflowRunResult` (non-streaming) or a + :class:`ResponseStream` (streaming). + + Raises: + ValueError: If the combination of *message*, *responses*, and + *checkpoint_id* is invalid. + RuntimeError: If the workflow is already running (concurrent + execution is not allowed). + """ + self._validate_run_params(message, responses, checkpoint_id) + if responses and checkpoint_id is None: + # Require at least one response key to match a currently-pending + # request; prevents silent replay against stale state while still + # allowing callers to accumulate prior answers across multi-round + # HITL. 
+ if not self._last_pending_request_ids: + raise ValueError( + f"responses={list(responses)!r} do not correspond to any pending request on " + f"workflow '{self.name}'. The workflow has no pending request_info events, " + f"so there is nothing to resume. Start a fresh run with 'message', or supply " + f"'checkpoint_id' to restore a specific checkpoint." + ) + if not (set(responses) & self._last_pending_request_ids): + raise ValueError( + f"responses={list(responses)!r} do not answer any of the currently-pending " + f"requests on workflow '{self.name}' ({sorted(self._last_pending_request_ids)!r}). " + f"Provide a response keyed by one of the pending request_ids." + ) + self._ensure_not_running() + + response_stream: ResponseStream[WorkflowEvent[Any], WorkflowRunResult] = ResponseStream( + self._run_core( + message=message, + responses=responses, + checkpoint_id=checkpoint_id, + checkpoint_storage=checkpoint_storage, + streaming=stream, + **kwargs, + ), + finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events), + cleanup_hooks=[self._run_cleanup], + ) + + if stream: + return response_stream + return response_stream.get_final_response() + + # ------------------------------------------------------------------ + # As agent + # ------------------------------------------------------------------ + + def as_agent( + self, + name: str | None = None, + *, + description: str | None = None, + context_providers: Sequence[Any] | None = None, + **kwargs: Any, + ) -> FunctionalWorkflowAgent: + """Wrap this workflow as an agent-compatible object. + + The returned :class:`FunctionalWorkflowAgent` exposes a ``run()`` + method that delegates to the workflow, surfaces ``request_info`` + events as function approval requests, and converts outputs into an + :class:`AgentResponse`. + + Signature mirrors graph :meth:`Workflow.as_agent` so polymorphic + code works over either flavor. + + Args: + name: Display name for the agent. 
Defaults to the workflow name. + description: Optional description override. Defaults to the + workflow's ``description``. + context_providers: Optional context providers to associate with + the agent. Stored for caller introspection. + **kwargs: Reserved for future parity with + :meth:`Workflow.as_agent`. + + Returns: + A :class:`FunctionalWorkflowAgent` wrapping this workflow. + """ + return FunctionalWorkflowAgent( + workflow=self, + name=name, + description=description, + context_providers=context_providers, + **kwargs, + ) + + # ------------------------------------------------------------------ + # Internal execution + # ------------------------------------------------------------------ + + async def _run_core( + self, + message: Any | None = None, + *, + responses: dict[str, Any] | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + streaming: bool = False, + **kwargs: Any, + ) -> AsyncIterable[WorkflowEvent[Any]]: + storage = checkpoint_storage or self._checkpoint_storage + + # Build context + ctx = RunContext(self.name, streaming=streaming, run_kwargs=kwargs if kwargs else None) + + # Restore from checkpoint if requested + prev_checkpoint_id: str | None = None + if checkpoint_id is not None: + if storage is None: + raise ValueError( + "Cannot restore from checkpoint without checkpoint_storage. " + "Provide checkpoint_storage parameter or set it on the @workflow decorator." + ) + checkpoint = await storage.load(checkpoint_id) + if checkpoint.graph_signature_hash != self.graph_signature_hash: + raise ValueError( + f"Checkpoint '{checkpoint_id}' was created by a different version of workflow " + f"'{checkpoint.workflow_name}' and is not compatible with the current version. " + f"The workflow's step structure may have changed since this checkpoint was saved." 
+ ) + prev_checkpoint_id = checkpoint_id + # Restore step cache + step_cache_data = checkpoint.state.get("_step_cache", {}) + ctx._import_step_cache(step_cache_data) + step_cache_auto_request_info_counts = checkpoint.state.get("_step_cache_auto_request_info_counts", {}) + ctx._import_step_cache_auto_request_info_counts(step_cache_auto_request_info_counts) + # Restore user state + ctx._state = {k: v for k, v in checkpoint.state.items() if not k.startswith("_")} + # Restore pending request info events + ctx._pending_requests = dict(checkpoint.pending_request_info_events) + # Restore original message for replay + if message is None: + message = checkpoint.state.get("_original_message") + + # For response-only replay (no checkpoint), restore cached state + if checkpoint_id is None and responses: + if message is None: + message = self._last_message + ctx._step_cache = dict(self._last_step_cache) + ctx._step_cache_auto_request_info_counts = dict(self._last_step_cache_auto_request_info_counts) + + # Store message for future replays + if message is not None: + self._last_message = message + + # Set responses for replay + if responses: + ctx._set_responses(responses) + + # Wire up per-step checkpointing + # Use a mutable list so the closure can update prev_checkpoint_id + ckpt_chain: list[str | None] = [prev_checkpoint_id] + if storage is not None: + + async def _on_step_completed() -> None: + ckpt_chain[0] = await self._save_checkpoint(ctx, storage, ckpt_chain[0]) + + ctx._on_step_completed = _on_step_completed + + # Tracing + attributes: dict[str, Any] = {OtelAttr.WORKFLOW_NAME: self.name} + if self.description: + attributes[OtelAttr.WORKFLOW_DESCRIPTION] = self.description + + with create_workflow_span(OtelAttr.WORKFLOW_RUN_SPAN, attributes) as span: + saw_request = False + try: + span.add_event(OtelAttr.WORKFLOW_STARTED) + + with _framework_event_origin(): + yield WorkflowEvent.started() + with _framework_event_origin(): + yield 
WorkflowEvent.status(WorkflowRunState.IN_PROGRESS) + + # Execute the user function + return_value = await self._execute(ctx, message) + + # Emit the return value as the workflow output. + if return_value is not None: + await ctx.add_event(WorkflowEvent.output(self.name, return_value)) + + # Persist step cache for response-only replay + self._last_step_cache = dict(ctx._step_cache) + self._last_step_cache_auto_request_info_counts = dict(ctx._step_cache_auto_request_info_counts) + + # Yield collected events. + # NOTE: Events are buffered during _execute() and yielded after + # the user function completes. This is *not* true streaming — + # all events have already been produced by this point. True + # per-token streaming from inner agent calls is a future + # enhancement. + for event in ctx._get_events(): + if event.type == "request_info": + saw_request = True + yield event + if event.type == "request_info": + with _framework_event_origin(): + yield WorkflowEvent.status(WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS) + + # Save final checkpoint if storage is available + if storage is not None: + await self._save_checkpoint(ctx, storage, ckpt_chain[0]) + + # Final status + if saw_request: + self._last_pending_request_ids = set(ctx._pending_requests) + with _framework_event_origin(): + yield WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS) + else: + # Clean completion — drop cross-run replay state. 
+ self._last_message = None + self._last_step_cache = {} + self._last_step_cache_auto_request_info_counts = {} + self._last_pending_request_ids = set() + with _framework_event_origin(): + yield WorkflowEvent.status(WorkflowRunState.IDLE) + + span.add_event(OtelAttr.WORKFLOW_COMPLETED) + + except WorkflowInterrupted: + # Persist step cache for response-only replay + self._last_step_cache = dict(ctx._step_cache) + self._last_step_cache_auto_request_info_counts = dict(ctx._step_cache_auto_request_info_counts) + self._last_pending_request_ids = set(ctx._pending_requests) + + # HITL interruption — yield events collected so far + for event in ctx._get_events(): + if event.type == "request_info": + saw_request = True + yield event + if event.type == "request_info": + with _framework_event_origin(): + yield WorkflowEvent.status(WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS) + + # Save checkpoint + if storage is not None: + await self._save_checkpoint(ctx, storage, ckpt_chain[0]) + + with _framework_event_origin(): + yield WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS) + + span.add_event(OtelAttr.WORKFLOW_COMPLETED) + + except Exception as exc: + # Yield any events collected before the failure + for event in ctx._get_events(): + yield event + + details = WorkflowErrorDetails.from_exception(exc) + with _framework_event_origin(): + yield WorkflowEvent.failed(details) + with _framework_event_origin(): + yield WorkflowEvent.status(WorkflowRunState.FAILED) + + span.add_event( + name=OtelAttr.WORKFLOW_ERROR, + attributes={ + "error.message": str(exc), + "error.type": type(exc).__name__, + }, + ) + capture_exception(span, exception=exc) + raise + + async def _execute(self, ctx: RunContext, message: Any) -> Any: + """Run the user's async function with the active context.""" + if message is not None and not self._non_ctx_param_names: + raise ValueError( + f"@workflow function '{self._func.__name__}' has no non-RunContext " + f"parameter to receive a message, but 
.run(message=...) was called " + f"with a non-None value. Either add a first parameter to the " + f"workflow function or omit 'message'." + ) + + token = _active_run_ctx.set(ctx) + try: + sig = inspect.signature(self._func) + params = list(sig.parameters.values()) + + # Resolve string annotations to actual types + try: + hints = typing.get_type_hints(self._func) + except Exception as exc: + logger.warning( + "Failed to resolve type hints for workflow function '%s': %s. " + "RunContext injection may not work if annotations are forward references.", + self._func.__name__, + exc, + ) + hints = {} + + # Build call arguments: inject RunContext and pass `message`. + # RunContext is detected by type annotation first, then by + # parameter name "ctx" — so both of these work: + # async def my_workflow(data: str, ctx: RunContext) -> str: + # async def my_workflow(data: str, ctx) -> str: + call_args: list[Any] = [] + message_injected = False + + for param in params: + resolved = hints.get(param.name, param.annotation) + if resolved is RunContext or param.name == "ctx": + call_args.append(ctx) + elif not message_injected: + # First non-ctx param gets the message + call_args.append(message) + message_injected = True + + return await self._func(*call_args) + finally: + _active_run_ctx.reset(token) + + # ------------------------------------------------------------------ + # Checkpoint helpers + # ------------------------------------------------------------------ + + async def _save_checkpoint( + self, + ctx: RunContext, + storage: CheckpointStorage, + previous_checkpoint_id: str | None = None, + ) -> str: + state = dict(ctx._state) + state["_step_cache"] = ctx._export_step_cache() + state["_step_cache_auto_request_info_counts"] = ctx._export_step_cache_auto_request_info_counts() + state["_original_message"] = self._last_message + + checkpoint = WorkflowCheckpoint( + workflow_name=self.name, + graph_signature_hash=self.graph_signature_hash, + 
previous_checkpoint_id=previous_checkpoint_id, + state=state, + pending_request_info_events=dict(ctx._pending_requests), + ) + return await storage.save(checkpoint) + + def _compute_signature_hash(self) -> str: + """Stable hash of the workflow's code shape. + + Mixes workflow name, statically-discovered step names, and a digest + of ``__code__.co_code`` + ``co_names``. The code digest catches + body changes that step-name discovery misses (e.g. attribute-access + step references). + """ + code = getattr(self._func, "__code__", None) + co_code_hex = hashlib.sha256(code.co_code).hexdigest() if code is not None else "" + co_names = tuple(sorted(code.co_names)) if code is not None else () + sig_data = { + "workflow": self.name, + "steps": sorted(self._step_names), + "co_code": co_code_hex, + "co_names": list(co_names), + } + import json + + canonical = json.dumps(sig_data, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(canonical.encode("utf-8")).hexdigest() + + @staticmethod + def _discover_step_names(func: Callable[..., Any]) -> list[str]: + """Extract step names referenced by the workflow function. + + Inspects the function's ``__code__.co_names`` and global scope for + ``StepWrapper`` instances. Steps accessed via module or class + attributes (``my_steps.fetch``) are missed here, but + :meth:`_compute_signature_hash` still captures them through the + ``co_code`` digest. 
+ """ + names: list[str] = [] + globs = getattr(func, "__globals__", {}) + code_names = getattr(getattr(func, "__code__", None), "co_names", ()) + for n in code_names: + obj = globs.get(n) + if isinstance(obj, StepWrapper): + names.append(obj.name) + return names + + # ------------------------------------------------------------------ + # Finalize / cleanup / validation (mirrors Workflow) + # ------------------------------------------------------------------ + + @staticmethod + def _finalize_events( + events: Sequence[WorkflowEvent[Any]], + *, + include_status_events: bool = False, + ) -> WorkflowRunResult: + filtered: list[WorkflowEvent[Any]] = [] + status_events: list[WorkflowEvent[Any]] = [] + + for ev in events: + if ev.type == "started": + continue + if ev.type == "status": + status_events.append(ev) + if include_status_events: + filtered.append(ev) + continue + filtered.append(ev) + + return WorkflowRunResult(filtered, status_events) + + @staticmethod + def _validate_run_params( + message: Any | None, + responses: dict[str, Any] | None, + checkpoint_id: str | None, + ) -> None: + if message is not None and responses is not None: + raise ValueError("Cannot provide both 'message' and 'responses'. Use one or the other.") + + if message is not None and checkpoint_id is not None: + raise ValueError("Cannot provide both 'message' and 'checkpoint_id'. Use one or the other.") + + if message is None and responses is None and checkpoint_id is None: + raise ValueError( + "Must provide at least one of: 'message' (new run), 'responses' (send responses), " + "or 'checkpoint_id' (resume from checkpoint)." + ) + + def _ensure_not_running(self) -> None: + if self._is_running: + raise RuntimeError("Workflow is already running. 
Concurrent executions are not allowed.") + self._is_running = True + + async def _run_cleanup(self) -> None: + self._is_running = False + + +# --------------------------------------------------------------------------- +# @workflow decorator +# --------------------------------------------------------------------------- + + +@overload +def workflow(func: Callable[..., Awaitable[Any]]) -> FunctionalWorkflow: ... + + +@overload +def workflow( + *, + name: str | None = None, + description: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, +) -> Callable[[Callable[..., Awaitable[Any]]], FunctionalWorkflow]: ... + + +@experimental(feature_id=ExperimentalFeature.FUNCTIONAL_WORKFLOWS) +def workflow( + func: Callable[..., Awaitable[Any]] | None = None, + *, + name: str | None = None, + description: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, +) -> FunctionalWorkflow | Callable[[Callable[..., Awaitable[Any]]], FunctionalWorkflow]: + """Decorator that converts an async function into a :class:`FunctionalWorkflow`. + + Supports both bare ``@workflow`` and parameterized + ``@workflow(name="my_wf")`` forms. + + The decorated function receives its input as the first positional argument + and a :class:`RunContext` instance wherever a parameter is annotated with + that type. The resulting :class:`FunctionalWorkflow` object exposes the + same ``run()`` interface as graph-based workflows. + + Args: + func: The async function to decorate (when using the bare + ``@workflow`` form). + name: Display name for the workflow. Defaults to ``func.__name__``. + description: Optional human-readable description. + checkpoint_storage: Default :class:`CheckpointStorage` for + persisting step results and workflow state. + + Returns: + A :class:`FunctionalWorkflow` (bare form) or a decorator that + produces one (parameterized form). + + Examples: + + .. 
code-block:: python + + # Bare form + @workflow + async def pipeline(data: str) -> str: + return await process(data) + + + # Parameterized form + @workflow(name="my_pipeline", checkpoint_storage=storage) + async def pipeline(data: str) -> str: ... + """ + if func is not None: + return FunctionalWorkflow(func, name=name, description=description, checkpoint_storage=checkpoint_storage) + + def _decorator(fn: Callable[..., Awaitable[Any]]) -> FunctionalWorkflow: + return FunctionalWorkflow(fn, name=name, description=description, checkpoint_storage=checkpoint_storage) + + return _decorator + + +# --------------------------------------------------------------------------- +# FunctionalWorkflowAgent +# --------------------------------------------------------------------------- + + +@experimental(feature_id=ExperimentalFeature.FUNCTIONAL_WORKFLOWS) +class FunctionalWorkflowAgent: + """Agent adapter for a :class:`FunctionalWorkflow`. + + Provides a ``run()`` method with the same overloaded signature as + :class:`BaseAgent` — returning an :class:`AgentResponse` (non-streaming) + or a :class:`ResponseStream[AgentResponseUpdate, AgentResponse]` + (streaming), making functional workflows usable anywhere an + agent-compatible object is expected. + + ``request_info`` events emitted by the underlying workflow are surfaced + as :class:`FunctionApprovalRequestContent` items (mirroring the graph + :class:`WorkflowAgent`), so HITL workflows are callable via this + adapter. Callers resume via ``responses=`` / ``checkpoint_id=``. + + Args: + workflow: The :class:`FunctionalWorkflow` to wrap. + name: Display name for the agent. Defaults to the workflow name. + description: Display description. Defaults to ``workflow.description``. + context_providers: Optional context providers stored for caller + introspection. + **kwargs: Reserved for future parity with :class:`WorkflowAgent`; + currently ignored. 
+ """ + + REQUEST_INFO_FUNCTION_NAME: str = "request_info" + + def __init__( + self, + workflow: FunctionalWorkflow, + *, + name: str | None = None, + description: str | None = None, + context_providers: Sequence[Any] | None = None, + **kwargs: Any, + ) -> None: + # kwargs is accepted for signature parity with graph Workflow.as_agent + # but not otherwise consumed. + del kwargs + self._workflow = workflow + self.name = name or workflow.name + self.id = f"FunctionalWorkflowAgent_{self.name}" + self.description: str | None = description if description is not None else workflow.description + self.context_providers: Sequence[Any] | None = context_providers + self._pending_requests: dict[str, WorkflowEvent[Any]] = {} + + @property + def pending_requests(self) -> dict[str, WorkflowEvent[Any]]: + """Pending request_info events emitted during the last run.""" + return self._pending_requests + + @overload + def run( + self, + messages: Any | None = None, + *, + stream: Literal[True], + responses: dict[str, Any] | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... + + @overload + def run( + self, + messages: Any | None = None, + *, + stream: Literal[False] = ..., + responses: dict[str, Any] | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse]: ... + + def run( + self, + messages: Any | None = None, + *, + stream: bool = False, + responses: dict[str, Any] | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse] | Awaitable[AgentResponse]: + """Run the underlying workflow and return the result as an agent response. + + Args: + messages: Input data forwarded to :meth:`FunctionalWorkflow.run`. 
+ + Keyword Args: + stream: If ``True``, return a :class:`ResponseStream` of + :class:`AgentResponseUpdate` items. + responses: HITL responses keyed by ``request_id``, forwarded to + the underlying workflow so HITL resumes work via this agent. + checkpoint_id: Optional checkpoint to restore from. + checkpoint_storage: Override the workflow's default + :class:`CheckpointStorage` for this run. + **kwargs: Extra keyword arguments forwarded to the workflow run. + + Returns: + An :class:`AgentResponse` (non-streaming) or a + :class:`ResponseStream` (streaming). + """ + if stream: + return self._run_streaming( + messages, + responses=responses, + checkpoint_id=checkpoint_id, + checkpoint_storage=checkpoint_storage, + **kwargs, + ) + return self._run_non_streaming( + messages, + responses=responses, + checkpoint_id=checkpoint_id, + checkpoint_storage=checkpoint_storage, + **kwargs, + ) + + async def _run_non_streaming( + self, + messages: Any | None, + *, + responses: dict[str, Any] | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + **kwargs: Any, + ) -> AgentResponse: + result = await self._workflow.run( + messages, + responses=responses, + checkpoint_id=checkpoint_id, + checkpoint_storage=checkpoint_storage, + **kwargs, + ) + return self._result_to_agent_response(result) + + def _run_streaming( + self, + messages: Any | None, + *, + responses: dict[str, Any] | None = None, + checkpoint_id: str | None = None, + checkpoint_storage: CheckpointStorage | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + from .._types import Content + + agent_name = self.name + # Clear per-run pending state up front + self._pending_requests = {} + workflow_stream = self._workflow.run( + messages, + stream=True, + responses=responses, + checkpoint_id=checkpoint_id, + checkpoint_storage=checkpoint_storage, + **kwargs, + ) + + async def _generate_updates() -> AsyncIterable[AgentResponseUpdate]: + 
async for event in workflow_stream: + if event.type == "output": + data = event.data + if isinstance(data, str): + contents: list[Content] = [Content.from_text(text=data)] + elif isinstance(data, Content): + contents = [data] + else: + contents = [Content.from_text(text=str(data))] + yield AgentResponseUpdate( + contents=contents, + role="assistant", + author_name=agent_name, + ) + elif event.type == "request_info": + approval = self._request_info_to_approval_request(event) + if approval is None: + continue + yield AgentResponseUpdate( + contents=[approval], + role="assistant", + author_name=agent_name, + ) + + return ResponseStream( + _generate_updates(), + finalizer=AgentResponse.from_updates, + ) + + def _request_info_to_approval_request(self, event: WorkflowEvent[Any]) -> Any: + """Convert a `request_info` event to `FunctionApprovalRequestContent`. + + Returns ``None`` if the event is missing a request_id (defensive; + `request_info` always sets one). + """ + from .._types import Content + + request_id = event.request_id + if not request_id: + return None + self._pending_requests[request_id] = event + function_call = Content.from_function_call( + call_id=request_id, + name=self.REQUEST_INFO_FUNCTION_NAME, + arguments={"request_id": request_id, "data": event.data}, + ) + return Content.from_function_approval_request( + id=request_id, + function_call=function_call, + additional_properties={"request_id": request_id}, + ) + + def _result_to_agent_response(self, result: WorkflowRunResult) -> AgentResponse: + from .._types import Content + from .._types import Message as Msg + + # Refresh pending_requests for this run. 
+ self._pending_requests = {} + + messages: list[Msg] = [] + for output in result.get_outputs(): + if isinstance(output, str): + contents: list[Content] = [Content.from_text(text=output)] + elif isinstance(output, Content): + contents = [output] + else: + contents = [Content.from_text(text=str(output))] + messages.append(Msg("assistant", contents)) + + # Surface pending request_info events so HITL callers see them. + approval_contents: list[Content] = [] + for event in result.get_request_info_events(): + approval = self._request_info_to_approval_request(event) + if approval is not None: + approval_contents.append(approval) + if approval_contents: + messages.append(Msg("assistant", approval_contents)) + + return AgentResponse(messages=messages) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index fc26db8953..c452f62bc2 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -340,10 +340,10 @@ async def _run_workflow_with_tracing( # Emit explicit start/status events to the stream with _framework_event_origin(): started = WorkflowEvent.started() - yield started + yield started # noqa: RUF070 with _framework_event_origin(): in_progress = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS) - yield in_progress + yield in_progress # noqa: RUF070 # Reset context for a new run if supported if reset_context: @@ -388,7 +388,7 @@ async def _run_workflow_with_tracing( emitted_in_progress_pending = True with _framework_event_origin(): pending_status = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS) - yield pending_status + yield pending_status # noqa: RUF070 # Workflow runs until idle - emit final status based on whether requests are pending if saw_request: with _framework_event_origin(): @@ -409,10 +409,10 @@ async def _run_workflow_with_tracing( details = 
WorkflowErrorDetails.from_exception(exc) with _framework_event_origin(): failed_event = WorkflowEvent.failed(details) - yield failed_event + yield failed_event # noqa: RUF070 with _framework_event_origin(): failed_status = WorkflowEvent.status(WorkflowRunState.FAILED) - yield failed_status + yield failed_status # noqa: RUF070 span.add_event( name=OtelAttr.WORKFLOW_ERROR, attributes={ diff --git a/python/packages/core/tests/workflow/test_function_executor.py b/python/packages/core/tests/workflow/test_function_executor.py index 6d292cb6eb..b45a667722 100644 --- a/python/packages/core/tests/workflow/test_function_executor.py +++ b/python/packages/core/tests/workflow/test_function_executor.py @@ -529,11 +529,12 @@ async def bad_handler(cls, data: str) -> str: assert "@handler on instance methods" in str(exc_info.value) async def test_async_staticmethod_detection_behavior(self): - """Document the behavior of asyncio.iscoroutinefunction with staticmethod descriptors. + """Document the behavior of inspect.iscoroutinefunction with staticmethod descriptors. This test explains why the unwrapping is necessary when decorators are stacked. 
""" import asyncio + import inspect # When @staticmethod is applied, it creates a descriptor async def my_async_func(): @@ -544,19 +545,19 @@ async def my_async_func(): static_wrapped = staticmethod(my_async_func) # Direct check on descriptor object fails (this is the bug) - assert not asyncio.iscoroutinefunction(static_wrapped) # type: ignore[reportDeprecated] + assert not inspect.iscoroutinefunction(static_wrapped) assert isinstance(static_wrapped, staticmethod) # But unwrapping __func__ reveals the async function unwrapped = static_wrapped.__func__ - assert asyncio.iscoroutinefunction(unwrapped) # type: ignore[reportDeprecated] + assert inspect.iscoroutinefunction(unwrapped) # When accessed via class attribute, Python's descriptor protocol # automatically unwraps it, so it works: class C: async_static = static_wrapped - assert asyncio.iscoroutinefunction(C.async_static) # type: ignore[reportDeprecated] # Works via descriptor protocol + assert inspect.iscoroutinefunction(C.async_static) # Works via descriptor protocol class TestExecutorExplicitTypes: diff --git a/python/packages/core/tests/workflow/test_functional_workflow.py b/python/packages/core/tests/workflow/test_functional_workflow.py new file mode 100644 index 0000000000..ba465ffe0b --- /dev/null +++ b/python/packages/core/tests/workflow/test_functional_workflow.py @@ -0,0 +1,1693 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Tests for the functional workflow API (@workflow, @step, RunContext).""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass + +import pytest + +from agent_framework import ( + AgentResponseUpdate, + ExperimentalFeature, + FunctionalWorkflow, + FunctionalWorkflowAgent, + InMemoryCheckpointStorage, + RunContext, + StepWrapper, + WorkflowEvent, + WorkflowRunResult, + WorkflowRunState, + get_run_context, + step, + workflow, +) +from agent_framework._workflows._functional import ( + RunContext as _RunContext, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +@step +async def add_one(x: int) -> int: + return x + 1 + + +@step +async def double(x: int) -> int: + return x * 2 + + +@step +async def to_upper(s: str) -> str: + return s.upper() + + +@step(name="custom_name") +async def named_step(x: int) -> int: + return x + 10 + + +@step +async def failing_step(x: int) -> int: + raise ValueError(f"step failed with {x}") + + +# --------------------------------------------------------------------------- +# Basic execution +# --------------------------------------------------------------------------- + + +class TestBasicExecution: + async def test_simple_sequential_pipeline(self): + @workflow + async def pipeline(x: int) -> int: + a = await add_one(x) + return await double(a) + + result = await pipeline.run(5) + assert isinstance(result, WorkflowRunResult) + outputs = result.get_outputs() + assert outputs == [12] # (5+1)*2 + + async def test_workflow_with_string_data(self): + @workflow + async def upper_pipeline(text: str) -> str: + return await to_upper(text) + + result = await upper_pipeline.run("hello") + assert result.get_outputs() == ["HELLO"] + + async def test_workflow_returns_result(self): + @workflow + async 
def simple(x: int) -> int: + return await add_one(x) + + result = await simple.run(10) + assert result.get_outputs() == [11] + + async def test_workflow_name_defaults_to_function_name(self): + @workflow + async def my_pipeline(x: int) -> int: + return x + + assert my_pipeline.name == "my_pipeline" + + async def test_workflow_custom_name(self): + @workflow(name="custom_wf", description="A test workflow") + async def wf(x: int) -> int: + return x + + assert wf.name == "custom_wf" + assert wf.description == "A test workflow" + + +# --------------------------------------------------------------------------- +# Event emission +# --------------------------------------------------------------------------- + + +class TestEventEmission: + async def test_step_events_emitted(self): + @workflow + async def pipeline(x: int) -> int: + return await add_one(x) + + result = await pipeline.run(5) + event_types = [e.type for e in result] + assert "executor_invoked" in event_types + assert "executor_completed" in event_types + assert "output" in event_types + + async def test_step_events_carry_executor_id(self): + @workflow + async def pipeline(x: int) -> int: + return await add_one(x) + + result = await pipeline.run(5) + invoked_events = [e for e in result if e.type == "executor_invoked"] + assert len(invoked_events) == 1 + assert invoked_events[0].executor_id == "add_one" + + completed_events = [e for e in result if e.type == "executor_completed"] + assert len(completed_events) == 1 + assert completed_events[0].executor_id == "add_one" + assert completed_events[0].data == 6 + + async def test_status_events_in_timeline(self): + @workflow + async def pipeline(x: int) -> int: + return x + + result = await pipeline.run(1) + states = [e.state for e in result.status_timeline()] + assert WorkflowRunState.IN_PROGRESS in states + assert WorkflowRunState.IDLE in states + + async def test_final_state_is_idle(self): + @workflow + async def pipeline(x: int) -> int: + return x + + result = await 
pipeline.run(1) + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_custom_event(self): + from agent_framework import WorkflowEvent + + @workflow + async def pipeline(x: int, ctx: RunContext) -> int: + await ctx.add_event(WorkflowEvent.emit("pipeline", "custom_data")) + return x + + result = await pipeline.run(1) + data_events = [e for e in result if e.type == "data"] + assert len(data_events) == 1 + assert data_events[0].data == "custom_data" + + +# --------------------------------------------------------------------------- +# Parallel execution +# --------------------------------------------------------------------------- + + +class TestParallelExecution: + async def test_parallel_tasks_with_gather(self): + @step + async def slow_add(x: int) -> int: + await asyncio.sleep(0.01) + return x + 1 + + @step + async def slow_double(x: int) -> int: + await asyncio.sleep(0.01) + return x * 2 + + @workflow + async def parallel_wf(x: int) -> list[int]: + a, b = await asyncio.gather(slow_add(x), slow_double(x)) + return [a, b] + + result = await parallel_wf.run(5) + outputs = result.get_outputs() + assert outputs == [[6, 10]] + + async def test_parallel_events_all_emitted(self): + @step + async def task_a(x: int) -> int: + return x + 1 + + @step + async def task_b(x: int) -> int: + return x * 2 + + @workflow + async def par_wf(x: int) -> tuple[int, int]: + a, b = await asyncio.gather(task_a(x), task_b(x)) + return (a, b) + + result = await par_wf.run(3) + invoked = [e for e in result if e.type == "executor_invoked"] + completed = [e for e in result if e.type == "executor_completed"] + assert len(invoked) == 2 + assert len(completed) == 2 + + +# --------------------------------------------------------------------------- +# HITL (request_info / resume) +# --------------------------------------------------------------------------- + + +class TestHITL: + async def test_request_info_interrupts(self): + @workflow + async def review_wf(doc: str, ctx: 
RunContext) -> str: + feedback = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") + return f"Final: {feedback}" + + # Phase 1: should interrupt with pending request + result = await review_wf.run("my doc") + assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + request_events = result.get_request_info_events() + assert len(request_events) == 1 + assert request_events[0].request_id == "req1" + + async def test_request_info_resume(self): + @workflow + async def review_wf(doc: str, ctx: RunContext) -> str: + feedback = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") + return f"Final: {feedback}" + + # Phase 1 + result1 = await review_wf.run("my doc") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + # Phase 2: resume with response + result2 = await review_wf.run(responses={"req1": "Looks great!"}) + outputs = result2.get_outputs() + assert outputs == ["Final: Looks great!"] + assert result2.get_final_state() == WorkflowRunState.IDLE + + async def test_untyped_ctx_parameter(self): + """ctx is injected by parameter name even without a RunContext annotation.""" + + @workflow # pyright: ignore[reportUnknownArgumentType] + async def review_wf(doc: str, ctx) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingParameterType] + feedback: str = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + return f"Final: {feedback}" + + result1 = await review_wf.run("my doc") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + result2 = await review_wf.run(responses={"req1": "LGTM"}) + assert result2.get_outputs() == ["Final: LGTM"] + + async def test_multiple_sequential_interrupts(self): + @workflow + async def multi_hitl(data: str, ctx: RunContext) -> str: + r1 = await ctx.request_info("step1", response_type=str, 
request_id="r1") + r2 = await ctx.request_info("step2", response_type=str, request_id="r2") + return f"{r1}+{r2}" + + # Phase 1: first interrupt + result1 = await multi_hitl.run("start") + assert len(result1.get_request_info_events()) == 1 + assert result1.get_request_info_events()[0].request_id == "r1" + + # Phase 2: respond to first, hits second + result2 = await multi_hitl.run(responses={"r1": "A"}) + assert len(result2.get_request_info_events()) == 1 + assert result2.get_request_info_events()[0].request_id == "r2" + + # Phase 3: respond to second + result3 = await multi_hitl.run(responses={"r1": "A", "r2": "B"}) + assert result3.get_outputs() == ["A+B"] + + async def test_request_info_auto_generates_id(self): + @workflow + async def auto_id_wf(x: int, ctx: RunContext) -> None: + await ctx.request_info("need data", response_type=str) + + result = await auto_id_wf.run(1) + events = result.get_request_info_events() + assert len(events) == 1 + assert events[0].request_id # should be a non-empty uuid string + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestErrorHandling: + async def test_step_failure_propagates(self): + @workflow + async def failing_wf(x: int) -> None: + await failing_step(x) + + with pytest.raises(ValueError, match="step failed with 42"): + await failing_wf.run(42) + + async def test_step_failure_emits_executor_failed(self): + @workflow + async def failing_wf(x: int) -> None: + await failing_step(x) + + # Use stream to collect events before the raise + stream = failing_wf.run(42, stream=True) + events: list[WorkflowEvent[object]] = [] + with pytest.raises(ValueError): + async for event in stream: + events.append(event) + + failed_events = [e for e in events if e.type == "executor_failed"] + assert len(failed_events) == 1 + assert failed_events[0].executor_id == "failing_step" + + async def 
test_workflow_failure_emits_failed_status(self): + @workflow + async def bad_wf(x: int) -> None: + raise RuntimeError("workflow broke") + + stream = bad_wf.run(42, stream=True) + events: list[WorkflowEvent[object]] = [] + with pytest.raises(RuntimeError, match="workflow broke"): + async for event in stream: + events.append(event) + + failed_events = [e for e in events if e.type == "failed"] + assert len(failed_events) == 1 + status_events = [e for e in events if e.type == "status"] + assert any(e.state == WorkflowRunState.FAILED for e in status_events) + + async def test_invalid_params_message_and_responses(self): + @workflow + async def wf(x: int) -> None: + pass + + with pytest.raises(ValueError, match="Cannot provide both"): + await wf.run("hello", responses={"r1": "val"}) + + async def test_invalid_params_message_and_checkpoint(self): + @workflow + async def wf(x: int) -> None: + pass + + with pytest.raises(ValueError, match="Cannot provide both"): + await wf.run("hello", checkpoint_id="abc") + + async def test_invalid_params_nothing(self): + @workflow + async def wf(x: int) -> None: + pass + + with pytest.raises(ValueError, match="Must provide at least one"): + await wf.run() + + +# --------------------------------------------------------------------------- +# Streaming +# --------------------------------------------------------------------------- + + +class TestStreaming: + async def test_streaming_yields_events(self): + @workflow + async def pipeline(x: int) -> int: + return await add_one(x) + + stream = pipeline.run(5, stream=True) + events: list[WorkflowEvent[object]] = [] + async for event in stream: + events.append(event) + + event_types = [e.type for e in events] + assert "started" in event_types + assert "executor_invoked" in event_types + assert "executor_completed" in event_types + assert "output" in event_types + + async def test_streaming_final_response(self): + @workflow + async def pipeline(x: int) -> int: + return await add_one(x) + + stream = 
pipeline.run(5, stream=True) + final = await stream.get_final_response() + assert isinstance(final, WorkflowRunResult) + assert final.get_outputs() == [6] + + async def test_streaming_context_reports_streaming(self): + streaming_flag = None + + @workflow + async def wf(x: int, ctx: RunContext) -> int: + nonlocal streaming_flag + streaming_flag = ctx.is_streaming() + return x + + stream = wf.run(1, stream=True) + await stream.get_final_response() + assert streaming_flag is True + + streaming_flag = None + await wf.run(1) + assert streaming_flag is False + + +# --------------------------------------------------------------------------- +# Step passthrough outside workflow +# --------------------------------------------------------------------------- + + +class TestStepPassthrough: + async def test_step_works_outside_workflow(self): + result = await add_one(10) + assert result == 11 + + async def test_named_step_outside_workflow(self): + result = await named_step(5) + assert result == 15 + + def test_step_wrapper_name(self): + assert add_one.name == "add_one" + assert named_step.name == "custom_name" + + def test_step_wrapper_is_step_wrapper(self): + assert isinstance(add_one, StepWrapper) + assert isinstance(named_step, StepWrapper) + + +# --------------------------------------------------------------------------- +# State management +# --------------------------------------------------------------------------- + + +class TestStateManagement: + async def test_get_set_state(self): + @workflow + async def stateful_wf(x: int, ctx: RunContext) -> int: + ctx.set_state("counter", x) + return ctx.get_state("counter") + + result = await stateful_wf.run(42) + assert result.get_outputs() == [42] + + async def test_get_state_default(self): + @workflow + async def wf(x: int, ctx: RunContext) -> str: + return ctx.get_state("missing", "default_val") + + result = await wf.run(1) + assert result.get_outputs() == ["default_val"] + + +# 
--------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- + + +class TestCheckpointing: + async def test_checkpoint_save_and_restore(self): + storage = InMemoryCheckpointStorage() + + @step + async def expensive(x: int) -> int: + return x * 100 + + @workflow(checkpoint_storage=storage) + async def ckpt_wf(x: int) -> int: + return await expensive(x) + + result = await ckpt_wf.run(5) + assert result.get_outputs() == [500] + + # Verify checkpoints were saved: 1 per-step + 1 final + checkpoints = await storage.list_checkpoints(workflow_name="ckpt_wf") + assert len(checkpoints) == 2 + + async def test_checkpoint_runtime_storage_override(self): + storage = InMemoryCheckpointStorage() + + @step + async def compute(x: int) -> int: + return x + 1 + + @workflow + async def wf(x: int) -> int: + return await compute(x) + + result = await wf.run(10, checkpoint_storage=storage) + assert result.get_outputs() == [11] + # 1 per-step checkpoint + 1 final checkpoint + checkpoints = await storage.list_checkpoints(workflow_name="wf") + assert len(checkpoints) == 2 + + async def test_checkpoint_restore_replays_cached_tasks(self): + storage = InMemoryCheckpointStorage() + call_count = 0 + + @step + async def counting_task(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + @workflow(checkpoint_storage=storage) + async def wf(x: int) -> int: + return await counting_task(x) + + # First run + result1 = await wf.run(5) + assert result1.get_outputs() == [6] + assert call_count == 1 + + # Get checkpoint ID + checkpoints = await storage.list_checkpoints(workflow_name="wf") + ckpt_id = checkpoints[0].checkpoint_id + + # Restore — step should replay from cache + result2 = await wf.run(checkpoint_id=ckpt_id) + assert result2.get_outputs() == [6] + assert call_count == 1 # not called again + + async def test_checkpoint_hitl_resume(self): + storage = 
InMemoryCheckpointStorage() + + @workflow(checkpoint_storage=storage) + async def hitl_wf(doc: str, ctx: RunContext) -> str: + feedback = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") + return f"Done: {feedback}" + + # Phase 1: interrupt + result1 = await hitl_wf.run("draft text") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + # Get checkpoint + checkpoints = await storage.list_checkpoints(workflow_name="hitl_wf") + ckpt_id = checkpoints[0].checkpoint_id + + # Phase 2: restore and respond + result2 = await hitl_wf.run(checkpoint_id=ckpt_id, responses={"req1": "Approved!"}) + assert result2.get_outputs() == ["Done: Approved!"] + + async def test_checkpoint_without_storage_raises(self): + @workflow + async def wf(x: int) -> int: + return x + + with pytest.raises(ValueError, match="checkpoint_storage"): + await wf.run(checkpoint_id="nonexistent") + + async def test_checkpoint_preserves_state(self): + storage = InMemoryCheckpointStorage() + + @workflow(checkpoint_storage=storage) + async def stateful_wf(x: int, ctx: RunContext) -> str: + ctx.set_state("key", "value") + feedback = await ctx.request_info("need info", response_type=str, request_id="r1") + val = ctx.get_state("key") + return f"{val}:{feedback}" + + # Phase 1 + result1 = await stateful_wf.run(1) + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + # Phase 2: restore and respond + checkpoints = await storage.list_checkpoints(workflow_name="stateful_wf") + ckpt_id = checkpoints[0].checkpoint_id + + result2 = await stateful_wf.run(checkpoint_id=ckpt_id, responses={"r1": "hello"}) + assert result2.get_outputs() == ["value:hello"] + + async def test_per_step_checkpoint_enables_crash_recovery(self): + """Simulates crash recovery: step 1 completes and is checkpointed, + then the workflow crashes in step 2. 
Restoring from the per-step + checkpoint should replay step 1 from cache without re-executing it.""" + storage = InMemoryCheckpointStorage() + step1_calls = 0 + step2_calls = 0 + + @step + async def slow_step1(x: int) -> int: + nonlocal step1_calls + step1_calls += 1 + return x + 10 + + @step + async def crashing_step2(x: int) -> int: + nonlocal step2_calls + step2_calls += 1 + if step2_calls == 1: + raise RuntimeError("simulated crash") + return x * 2 + + @workflow(checkpoint_storage=storage) + async def crash_wf(x: int) -> int: + a = await slow_step1(x) + return await crashing_step2(a) + + # First run: step1 succeeds and checkpoints, step2 crashes + with pytest.raises(RuntimeError, match="simulated crash"): + await crash_wf.run(5) + + assert step1_calls == 1 + assert step2_calls == 1 + + # A per-step checkpoint was saved after step1 completed + checkpoints = await storage.list_checkpoints(workflow_name="crash_wf") + assert len(checkpoints) >= 1 + ckpt_id = checkpoints[0].checkpoint_id + + # Restore from checkpoint: step1 replays from cache, step2 runs fresh + result = await crash_wf.run(checkpoint_id=ckpt_id) + assert result.get_outputs() == [30] # (5+10)*2 + assert step1_calls == 1 # NOT called again — replayed from cache + assert step2_calls == 2 # called again, succeeds this time + + async def test_per_step_checkpoint_chain(self): + """Each step creates a new checkpoint chained to the previous one.""" + storage = InMemoryCheckpointStorage() + + @step + async def s1(x: int) -> int: + return x + 1 + + @step + async def s2(x: int) -> int: + return x + 2 + + @step + async def s3(x: int) -> int: + return x + 3 + + @workflow(checkpoint_storage=storage) + async def multi_step_wf(x: int) -> int: + a = await s1(x) + b = await s2(a) + return await s3(b) + + result = await multi_step_wf.run(0) + assert result.get_outputs() == [6] # 0+1+2+3 + + # 3 per-step checkpoints + 1 final = 4 + checkpoints = await storage.list_checkpoints(workflow_name="multi_step_wf") + assert 
len(checkpoints) == 4 + + async def test_no_checkpoint_on_cache_hit(self): + """During replay, cached steps should NOT create additional checkpoints.""" + storage = InMemoryCheckpointStorage() + + @step + async def compute(x: int) -> int: + return x + 1 + + @workflow(checkpoint_storage=storage) + async def wf(x: int) -> int: + return await compute(x) + + # First run: 1 per-step + 1 final = 2 checkpoints + await wf.run(5) + checkpoints = await storage.list_checkpoints(workflow_name="wf") + assert len(checkpoints) == 2 + ckpt_id = checkpoints[0].checkpoint_id + + # Restore: step replays from cache (no new per-step checkpoint), + # but final checkpoint still saved = 1 new checkpoint + await wf.run(checkpoint_id=ckpt_id) + checkpoints = await storage.list_checkpoints(workflow_name="wf") + assert len(checkpoints) == 3 # 2 from first run + 1 final from restore + + +# --------------------------------------------------------------------------- +# Branching / control flow +# --------------------------------------------------------------------------- + + +class TestControlFlow: + async def test_if_else_branching(self): + @dataclass + class Classification: + is_spam: bool + + @step + async def classify(text: str) -> Classification: + return Classification(is_spam="spam" in text.lower()) + + @step + async def process_normal(text: str) -> str: + return f"processed: {text}" + + @step + async def quarantine(text: str) -> str: + return f"quarantined: {text}" + + @workflow + async def email_pipeline(email: str) -> str: + cl = await classify(email) + if cl.is_spam: + result = await quarantine(email) + else: + result = await process_normal(email) + return result + + result_spam = await email_pipeline.run("Buy spam now!") + assert result_spam.get_outputs() == ["quarantined: Buy spam now!"] + + result_normal = await email_pipeline.run("Hello friend") + assert result_normal.get_outputs() == ["processed: Hello friend"] + + +# 
--------------------------------------------------------------------------- +# Nested workflow calls +# --------------------------------------------------------------------------- + + +class TestNestedWorkflows: + async def test_nested_workflow_as_task(self): + @step + async def step_a(x: int) -> int: + return x + 1 + + @workflow + async def inner_wf(x: int) -> int: + return await step_a(x) + + @step + async def call_inner(x: int) -> int: + result = await inner_wf.run(x) + return result.get_outputs()[0] + + @workflow + async def outer_wf(x: int) -> int: + return await call_inner(x) + + result = await outer_wf.run(5) + assert result.get_outputs() == [6] + + +# --------------------------------------------------------------------------- +# as_agent() +# --------------------------------------------------------------------------- + + +class TestAsAgent: + async def test_as_agent_returns_agent(self): + @workflow + async def wf(x: int) -> str: + return f"result: {x}" + + agent = wf.as_agent() + assert agent.name == "wf" + + async def test_as_agent_custom_name(self): + @workflow + async def wf(x: int) -> int: + return x + + agent = wf.as_agent(name="my_agent") + assert agent.name == "my_agent" + + async def test_as_agent_run(self): + @workflow + async def wf(x: int) -> int: + return await add_one(x) + + agent = wf.as_agent() + response = await agent.run(10) + assert response.text == "11" + + async def test_as_agent_run_streaming(self): + @workflow + async def wf(x: int) -> str: + return f"result: {x}" + + agent = wf.as_agent() + stream = agent.run(10, stream=True) + updates: list[AgentResponseUpdate] = [] + async for update in stream: + updates.append(update) + assert len(updates) == 1 + assert updates[0].text == "result: 10" + + response = await stream.get_final_response() + assert len(response.messages) >= 1 + + async def test_as_agent_has_id_and_description(self): + @workflow(description="A test workflow") + async def wf(x: int) -> int: + return x + + agent = 
wf.as_agent(name="my_agent") + assert agent.id == "FunctionalWorkflowAgent_my_agent" + assert agent.description == "A test workflow" + + +# --------------------------------------------------------------------------- +# Concurrent execution guard +# --------------------------------------------------------------------------- + + +class TestConcurrencyGuard: + async def test_concurrent_run_raises(self): + @workflow + async def slow_wf(x: int) -> int: + await asyncio.sleep(0.1) + return x + + # Start first run + stream = slow_wf.run(1, stream=True) + + # Try to start second run while first is active + with pytest.raises(RuntimeError, match="already running"): + slow_wf.run(2, stream=True) + + # Consume the stream to clean up + await stream.get_final_response() + + async def test_run_after_completion(self): + @workflow + async def wf(x: int) -> int: + return x + + result1 = await wf.run(1) + assert result1.get_outputs() == [1] + + # Should be able to run again after first completes + result2 = await wf.run(2) + assert result2.get_outputs() == [2] + + +# --------------------------------------------------------------------------- +# Decorator forms +# --------------------------------------------------------------------------- + + +class TestDecoratorForms: + def test_step_bare_decorator(self): + @step + async def my_step(x: int) -> int: + return x + + assert isinstance(my_step, StepWrapper) + assert my_step.name == "my_step" + + def test_step_with_name(self): + @step(name="renamed") + async def my_step(x: int) -> int: + return x + + assert isinstance(my_step, StepWrapper) + assert my_step.name == "renamed" + + def test_workflow_bare_decorator(self): + @workflow + async def my_wf(x: int) -> None: + pass + + assert isinstance(my_wf, FunctionalWorkflow) + assert my_wf.name == "my_wf" + + def test_workflow_with_params(self): + @workflow(name="custom", description="desc") + async def my_wf(x: int) -> None: + pass + + assert isinstance(my_wf, FunctionalWorkflow) + assert 
my_wf.name == "custom" + assert my_wf.description == "desc" + + +# --------------------------------------------------------------------------- +# include_status_events +# --------------------------------------------------------------------------- + + +class TestIncludeStatusEvents: + async def test_status_events_excluded_by_default(self): + @workflow + async def wf(x: int) -> int: + return x + + result = await wf.run(1) + status_in_list = [e for e in result if e.type == "status"] + assert len(status_in_list) == 0 + + async def test_status_events_included_when_requested(self): + @workflow + async def wf(x: int) -> int: + return x + + result = await wf.run(1, include_status_events=True) + status_in_list = [e for e in result if e.type == "status"] + assert len(status_in_list) > 0 + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + async def test_workflow_with_no_tasks(self): + @workflow + async def no_tasks(x: int) -> int: + return x * 2 + + result = await no_tasks.run(5) + assert result.get_outputs() == [10] + + async def test_workflow_with_no_output(self): + @workflow + async def silent_wf(x: int) -> None: + pass # returns None — no output emitted + + result = await silent_wf.run(5) + assert result.get_outputs() == [] + + async def test_return_value_auto_yields_output(self): + """Returning a non-None value automatically emits it as an output.""" + + @workflow + async def wf(x: int) -> int: + return x * 3 + + result = await wf.run(5) + assert result.get_outputs() == [15] + + async def test_step_called_multiple_times(self): + @workflow + async def wf(x: int) -> int: + a = await add_one(x) + b = await add_one(a) + return await add_one(b) + + result = await wf.run(0) + assert result.get_outputs() == [3] # 0+1+1+1 + + # Should have 3 invoked and 3 completed events for add_one + invoked = [e for e in result if e.type == 
"executor_invoked"] + completed = [e for e in result if e.type == "executor_completed"] + assert len(invoked) == 3 + assert len(completed) == 3 + + +# --------------------------------------------------------------------------- +# Recovery after errors +# --------------------------------------------------------------------------- + + +class TestRecoveryAfterErrors: + async def test_run_after_failure_is_allowed(self): + @workflow + async def wf(x: int) -> int: + if x == 1: + raise RuntimeError("boom") + return x + + with pytest.raises(RuntimeError, match="boom"): + await wf.run(1) + + # Must be able to run again after the failure + result = await wf.run(2) + assert result.get_outputs() == [2] + + async def test_step_sync_function_raises(self): + with pytest.raises(TypeError, match="async functions"): + + @step # pyright: ignore[reportArgumentType] + def not_async(x: int) -> int: # pyright: ignore[reportUnusedFunction] + return x + + +# --------------------------------------------------------------------------- +# WorkflowInterrupted is BaseException +# --------------------------------------------------------------------------- + + +class TestWorkflowInterruptedIsBaseException: + async def test_except_exception_does_not_catch_interrupt(self): + """User code with ``except Exception`` should not catch WorkflowInterrupted.""" + caught = False + + @workflow + async def wf(x: int, ctx: RunContext) -> str: + nonlocal caught + try: + return await ctx.request_info("need review", response_type=str, request_id="r1") + except Exception: + # This should NOT catch WorkflowInterrupted + caught = True + return "caught!" + + result = await wf.run("data") + # Should have a pending request, NOT "caught!" 
+ assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + assert result.get_outputs() == [] + assert caught is False + + +# --------------------------------------------------------------------------- +# Checkpoint validation +# --------------------------------------------------------------------------- + + +class TestCheckpointValidation: + async def test_checkpoint_signature_mismatch_raises(self): + from agent_framework import WorkflowCheckpoint + + storage = InMemoryCheckpointStorage() + + @workflow(name="my_wf", checkpoint_storage=storage) + async def wf(x: int) -> int: + return x + + # Manually create a checkpoint with a different signature hash + bad_checkpoint = WorkflowCheckpoint( + workflow_name="my_wf", + graph_signature_hash="totally_different_hash", + state={"_step_cache": {}, "_original_message": 1}, + ) + ckpt_id = await storage.save(bad_checkpoint) + + # Should fail due to hash mismatch + with pytest.raises(ValueError, match="not compatible"): + await wf.run(checkpoint_id=ckpt_id) + + async def test_import_step_cache_malformed_key(self): + ctx = _RunContext("test") + with pytest.raises(ValueError, match="Corrupted step cache"): + ctx._import_step_cache({"invalid_key_no_separator": 42}) # pyright: ignore[reportPrivateUsage] + + async def test_import_step_cache_non_integer_index(self): + ctx = _RunContext("test") + with pytest.raises(ValueError, match="Corrupted step cache"): + ctx._import_step_cache({"step_name::abc": 42}) # pyright: ignore[reportPrivateUsage] + + +# --------------------------------------------------------------------------- +# executor_bypassed event on replay (review comment #3) +# --------------------------------------------------------------------------- + + +class TestExecutorBypassed: + async def test_cached_step_emits_bypassed_event(self): + """When a step replays from cache, it should emit executor_bypassed.""" + storage = InMemoryCheckpointStorage() + call_count = 0 + + @step + async def tracked(x: 
int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + @workflow(checkpoint_storage=storage) + async def wf(x: int) -> int: + return await tracked(x) + + # First run — live execution + result1 = await wf.run(5) + assert result1.get_outputs() == [6] + assert call_count == 1 + + event_types1 = [e.type for e in result1] + assert "executor_invoked" in event_types1 + assert "executor_completed" in event_types1 + assert "executor_bypassed" not in event_types1 + + # Restore from checkpoint — cached replay + ckpt_id = (await storage.list_checkpoints(workflow_name="wf"))[-1].checkpoint_id + result2 = await wf.run(checkpoint_id=ckpt_id) + assert result2.get_outputs() == [6] + assert call_count == 1 # not called again + + event_types2 = [e.type for e in result2] + assert "executor_bypassed" in event_types2 + # Should NOT have the live-execution pair + assert "executor_invoked" not in event_types2 + assert "executor_completed" not in event_types2 + + async def test_bypassed_event_carries_cached_data(self): + storage = InMemoryCheckpointStorage() + + @step + async def compute(x: int) -> int: + return x * 10 + + @workflow(checkpoint_storage=storage) + async def wf(x: int) -> int: + return await compute(x) + + await wf.run(3) + ckpt_id = (await storage.list_checkpoints(workflow_name="wf"))[-1].checkpoint_id + + result = await wf.run(checkpoint_id=ckpt_id) + bypassed = [e for e in result if e.type == "executor_bypassed"] + assert len(bypassed) == 1 + assert bypassed[0].executor_id == "compute" + assert bypassed[0].data == 30 + + +# --------------------------------------------------------------------------- +# request_info inside @step (review comment #1) +# --------------------------------------------------------------------------- + + +class TestRequestInfoInStep: + async def test_step_with_run_context_injection(self): + """A @step function with a RunContext parameter gets it auto-injected.""" + + @step + async def review_step(doc: str, ctx: RunContext) -> str: + 
feedback = await ctx.request_info({"draft": doc}, response_type=str, request_id="s1") + return f"reviewed: {feedback}" + + @workflow + async def wf(doc: str) -> str: + return await review_step(doc) + + # Phase 1: should interrupt + result1 = await wf.run("my doc") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + assert len(result1.get_request_info_events()) == 1 + assert result1.get_request_info_events()[0].request_id == "s1" + + # Phase 2: resume + result2 = await wf.run(responses={"s1": "LGTM"}) + assert result2.get_outputs() == ["reviewed: LGTM"] + + async def test_step_works_outside_workflow_with_explicit_ctx(self): + """Outside a workflow, the step is transparent — caller provides ctx.""" + + @step + async def needs_ctx(data: str, ctx: RunContext) -> str: + val = ctx.get_state("key", "default") + return f"{data}:{val}" + + # Outside a workflow, pass through directly — caller supplies ctx + ctx = RunContext("test") + ctx.set_state("key", "hello") + result = await needs_ctx("data", ctx) + assert result == "data:hello" + + async def test_step_injects_ctx_before_user_positional_parameters(self): + """RunContext injection should not conflict when ctx is the first step parameter.""" + + @step + async def needs_ctx_first(ctx: RunContext, data: str) -> str: + ctx.set_state("seen", data) + return f"{data}:{ctx.get_state('seen')}" + + @workflow + async def wf(data: str) -> str: + return await needs_ctx_first(data) + + result = await wf.run("draft") + + assert result.get_outputs() == ["draft:draft"] + + async def test_get_run_context_inside_workflow(self): + """get_run_context() returns the active RunContext inside a workflow.""" + from agent_framework import get_run_context + + captured_ctx = None + + @step + async def capture_ctx(x: int) -> int: + nonlocal captured_ctx + captured_ctx = get_run_context() + return x + + @workflow + async def wf(x: int) -> int: + return await capture_ctx(x) + + await wf.run(1) + assert captured_ctx is 
not None + assert isinstance(captured_ctx, RunContext) + + async def test_get_run_context_outside_workflow(self): + """get_run_context() returns None outside a workflow.""" + from agent_framework import get_run_context + + assert get_run_context() is None + + +# --------------------------------------------------------------------------- +# None response handling (review comment #2) +# --------------------------------------------------------------------------- + + +class TestNoneResponseHandling: + async def test_none_response_logs_warning(self): + """Providing None as a response value should log a warning.""" + + @workflow + async def wf(doc: str, ctx: RunContext) -> str: + val = await ctx.request_info("need input", response_type=str, request_id="r1") + return f"got: {val}" + + # Phase 1 + await wf.run("start") + + # Phase 2: resume with None response — should warn but still work + with caplog_context(logging.getLogger("agent_framework._workflows._functional")) as logs: + result = await wf.run(responses={"r1": None}) + + assert result.get_outputs() == ["got: None"] + assert any("None" in msg and "r1" in msg for msg in logs) + + async def test_none_response_is_returned(self): + """None is a valid (if discouraged) response value.""" + + @workflow + async def wf(x: int, ctx: RunContext) -> str: + val = await ctx.request_info("need data", response_type=str, request_id="r1") + return f"value={val}" + + await wf.run(1) + result = await wf.run(responses={"r1": None}) + assert result.get_outputs() == ["value=None"] + + +# Helper for capturing log messages + + +@contextmanager +def caplog_context(target_logger: logging.Logger) -> Iterator[list[str]]: + """Capture log messages from a specific logger.""" + messages: list[str] = [] + + class _Handler(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + messages.append(self.format(record)) + + handler = _Handler() + handler.setLevel(logging.WARNING) + target_logger.addHandler(handler) + try: + yield messages 
+ finally: + target_logger.removeHandler(handler) + + +# --------------------------------------------------------------------------- +# Combined regression tests (cross-cutting review comments #1, #2, #3) +# --------------------------------------------------------------------------- + + +class TestHITLInStepWithCaching: + """Regression tests: request_info inside @step combined with caching and bypass.""" + + async def test_preceding_step_bypassed_on_hitl_resume(self): + """When a step after a completed step calls request_info and interrupts, + resuming should bypass the first step (cached) and re-execute the HITL step.""" + call_count_a = 0 + + @step + async def step_a(x: int) -> int: + nonlocal call_count_a + call_count_a += 1 + return x + 1 + + @step + async def step_b(val: int, ctx: RunContext) -> str: + feedback = await ctx.request_info({"val": val}, response_type=str, request_id="r1") + return f"{val}:{feedback}" + + @workflow + async def wf(x: int) -> str: + a = await step_a(x) + return await step_b(a) + + # Phase 1: step_a completes, step_b interrupts + result1 = await wf.run(5) + assert call_count_a == 1 + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + # Phase 2: resume — step_a should be bypassed, step_b re-executes + result2 = await wf.run(responses={"r1": "ok"}) + assert call_count_a == 1 # step_a not called again + assert result2.get_outputs() == ["6:ok"] + + event_types = [e.type for e in result2] + assert "executor_bypassed" in event_types + + async def test_hitl_step_with_checkpoint_full_lifecycle(self): + """Full lifecycle: run -> interrupt -> resume -> checkpoint restore -> all bypassed.""" + storage = InMemoryCheckpointStorage() + + @step + async def compute(x: int) -> int: + return x * 10 + + @step + async def review(val: int, ctx: RunContext) -> str: + feedback = await ctx.request_info({"val": val}, response_type=str, request_id="rev") + return f"reviewed({val}):{feedback}" + + 
@workflow(checkpoint_storage=storage) + async def wf(x: int) -> str: + v = await compute(x) + return await review(v) + + # Phase 1: interrupt + result1 = await wf.run(3) + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + # Phase 2: resume + result2 = await wf.run(responses={"rev": "LGTM"}) + assert result2.get_outputs() == ["reviewed(30):LGTM"] + + # Phase 3: restore from latest checkpoint -- both steps should be bypassed + ckpt_id = (await storage.list_checkpoints(workflow_name="wf"))[-1].checkpoint_id + result3 = await wf.run(checkpoint_id=ckpt_id) + assert result3.get_outputs() == ["reviewed(30):LGTM"] + + event_types3 = [e.type for e in result3] + bypassed = [e for e in result3 if e.type == "executor_bypassed"] + assert len(bypassed) == 2 + assert "executor_invoked" not in event_types3 + + async def test_none_response_in_step_request_info(self): + """None response inside a @step request_info should warn and return None.""" + + @step + async def needs_feedback(doc: str, ctx: RunContext) -> str: + val = await ctx.request_info({"doc": doc}, response_type=str, request_id="r1") + return f"got:{val}" + + @workflow + async def wf(doc: str) -> str: + return await needs_feedback(doc) + + await wf.run("draft") + + with caplog_context(logging.getLogger("agent_framework._workflows._functional")) as logs: + result = await wf.run(responses={"r1": None}) + + assert result.get_outputs() == ["got:None"] + assert any("None" in msg and "r1" in msg for msg in logs) + + async def test_step_hitl_does_not_emit_executor_failed(self): + """WorkflowInterrupted from request_info inside a step should NOT emit executor_failed.""" + + @step + async def hitl_step(x: int, ctx: RunContext) -> str: + return await ctx.request_info("need data", response_type=str, request_id="r1") + + @workflow + async def wf(x: int) -> str: + return await hitl_step(x) + + result = await wf.run(1) + event_types = [e.type for e in result] + assert "executor_failed" not in 
event_types + assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + +# --------------------------------------------------------------------------- +# Regression tests for ultrareview findings +# --------------------------------------------------------------------------- + + +class TestDeterministicAutoRequestId: + """Regression for bug_001: auto-generated request_info ids must be stable across replay.""" + + async def test_auto_request_id_roundtrips_on_resume(self): + @workflow + async def wf(x: int, ctx: RunContext) -> str: + # No request_id — framework must generate a deterministic one + val = await ctx.request_info("need data", response_type=str) + return f"got:{val}" + + result1 = await wf.run(1) + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + requests = result1.get_request_info_events() + assert len(requests) == 1 + rid = requests[0].request_id + assert rid # non-empty + + # Resume with the id the caller just received. 
+ result2 = await wf.run(responses={rid: "hello"}) + assert result2.get_final_state() == WorkflowRunState.IDLE + assert result2.get_outputs() == ["got:hello"] + + async def test_multiple_auto_ids_are_distinct_and_stable(self): + @workflow + async def wf(x: int, ctx: RunContext) -> str: + a = await ctx.request_info("first", response_type=str) + b = await ctx.request_info("second", response_type=str) + return f"{a}/{b}" + + r1 = await wf.run(1) + rid1 = r1.get_request_info_events()[0].request_id + r2 = await wf.run(responses={rid1: "A"}) + rid2 = r2.get_request_info_events()[0].request_id + assert rid1 != rid2 + r3 = await wf.run(responses={rid1: "A", rid2: "B"}) + assert r3.get_outputs() == ["A/B"] + + async def test_cached_step_advances_auto_request_id_counter(self): + call_count = 0 + + @step + async def first_review(value: int, ctx: RunContext) -> str: + nonlocal call_count + call_count += 1 + return await ctx.request_info({"step": "first", "value": value}, response_type=str) + + @step + async def second_review(value: int, ctx: RunContext) -> str: + return await ctx.request_info({"step": "second", "value": value}, response_type=str) + + @workflow + async def wf(value: int) -> str: + first = await first_review(value) + second = await second_review(value) + return f"{first}/{second}" + + first_run = await wf.run(1) + first_request_id = first_run.get_request_info_events()[0].request_id + assert first_request_id == "auto::0" + + second_run = await wf.run(responses={first_request_id: "A"}) + second_request_id = second_run.get_request_info_events()[0].request_id + assert second_request_id == "auto::1" + completed_call_count = call_count + + final_run = await wf.run(responses={first_request_id: "A", second_request_id: "B"}) + + assert call_count == completed_call_count + assert final_run.get_outputs() == ["A/B"] + + +class TestPendingRequestsPruned: + """Regression for bug_007: resolved requests must be pruned from _pending_requests.""" + + async def 
test_final_checkpoint_no_longer_claims_resolved_requests_pending(self): + storage = InMemoryCheckpointStorage() + + @workflow(checkpoint_storage=storage) + async def wf(x: int, ctx: RunContext) -> str: + a = await ctx.request_info("q1", response_type=str, request_id="r1") + b = await ctx.request_info("q2", response_type=str, request_id="r2") + return f"{a}/{b}" + + await wf.run(1) + await wf.run(responses={"r1": "A"}) + result = await wf.run(responses={"r1": "A", "r2": "B"}) + assert result.get_final_state() == WorkflowRunState.IDLE + # Latest checkpoint must show no pending requests. + checkpoints = await storage.list_checkpoints(workflow_name="wf") + assert checkpoints, "expected at least one checkpoint to have been saved" + final = checkpoints[-1] + assert final.pending_request_info_events == {} + + +class TestArityValidation: + """Regression for merged_bug_003: validate workflow signature arity.""" + + def test_multi_non_ctx_param_rejected_at_decoration(self): + with pytest.raises(ValueError, match="multiple non-RunContext parameters"): + + @workflow + async def wf(a: str, b: str, ctx: RunContext) -> str: + return f"{a}+{b}" + + async def test_ctx_only_workflow_with_message_raises_clear_error(self): + @workflow + async def wf(ctx: RunContext) -> str: + return "no message used" + + with pytest.raises(ValueError, match="no non-RunContext parameter"): + await wf.run("important input") + + def test_ctx_only_workflow_decoration_succeeds(self): + # Decoration must not raise even though the workflow has no + # message-receiving parameter. (Running it without a message still + # requires providing responses or a checkpoint_id — that's + # _validate_run_params's job, not ours.) 
+ @workflow + async def wf(ctx: RunContext) -> str: + return "ok" + + assert wf is not None + + +class TestStaleResponsesRejected: + """Regression for bug_014: stale responses after clean completion must be rejected.""" + + async def test_responses_after_clean_completion_raise(self): + @workflow + async def wf(x: int) -> int: + return x * 2 + + await wf.run(5) # clean completion, no pending requests + with pytest.raises(ValueError, match="no pending request_info"): + await wf.run(responses={"stale": "x"}) + + async def test_responses_mismatched_key_raises(self): + @workflow + async def wf(x: int, ctx: RunContext) -> str: + return await ctx.request_info("q", response_type=str, request_id="r1") + + await wf.run(1) # interrupts with r1 pending + with pytest.raises(ValueError, match="do not answer"): + await wf.run(responses={"definitely_not_r1": "x"}) + + +class TestReservedStateKeys: + """Regression for bug_017: set_state must reject underscore-prefixed keys.""" + + async def test_underscore_key_rejected(self): + @workflow + async def wf(x: int, ctx: RunContext) -> int: + ctx.set_state("_private", "user value") + return x + + with pytest.raises(ValueError, match="reserved for framework"): + await wf.run(1) + + async def test_normal_key_still_works(self): + @workflow + async def wf(x: int, ctx: RunContext) -> int: + ctx.set_state("normal_key", "v") + assert ctx.get_state("normal_key") == "v" + return x + + r = await wf.run(1) + assert r.get_outputs() == [1] + + +class TestDeepcopyOnCacheHit: + """Regression for bug_002: cache hits must not deepcopy args.""" + + async def test_step_with_non_deepcopyable_arg_replays(self): + import threading + + @step + async def takes_lock(lock: threading.Lock, n: int) -> int: + return n + 1 + + @workflow + async def wf(x: int) -> int: + lock = threading.Lock() + return await takes_lock(lock, x) + + # First run — must succeed despite threading.Lock not being deepcopyable + # (deepcopy now wrapped in try/except, falls back to live 
reference for + # the invocation_data event only). + r1 = await wf.run(5) + assert r1.get_outputs() == [6] + + +class TestStepDiscoveryAttributeAccess: + """Regression for bug_008: checkpoint hash must differ when function body changes.""" + + async def test_signature_hash_changes_when_function_body_changes(self): + @workflow + async def wf_a(x: int) -> int: + return x + 1 + + @workflow(name="wf_b") + async def wf_b(x: int) -> int: + return x * 100 + + # Two different function bodies -> different hashes even though the + # static step-name scan would produce the same empty list. + assert wf_a.graph_signature_hash != wf_b.graph_signature_hash + + +class TestAsAgentSignatureParity: + """Regression for bug_015: as_agent signature must accept description/context_providers.""" + + async def test_as_agent_accepts_description_override(self): + @workflow(description="workflow level") + async def wf(x: str) -> str: + return x.upper() + + agent = wf.as_agent(name="a", description="agent level") + assert agent.description == "agent level" + + async def test_as_agent_accepts_context_providers_kwarg(self): + @workflow + async def wf(x: str) -> str: + return x + + providers = [object()] # opaque placeholder; must be stored without error + agent = wf.as_agent(context_providers=providers) + assert list(agent.context_providers or []) == providers + + async def test_as_agent_description_defaults_to_workflow_description(self): + @workflow(description="from workflow") + async def wf(x: str) -> str: + return x + + agent = wf.as_agent() + assert agent.description == "from workflow" + + +class TestFunctionalWorkflowAgentHITL: + """Regression for bug_013: .as_agent() must surface request_info events.""" + + async def test_request_info_surfaces_as_function_approval_request(self): + @workflow + async def wf(x: str, ctx: RunContext) -> str: + answer = await ctx.request_info({"need": x}, response_type=str, request_id="rid-1") + return f"got:{answer}" + + agent = wf.as_agent() + response = 
await agent.run("topic") + + # Agent must expose the pending request_id. + assert "rid-1" in agent.pending_requests + + # Response must contain at least one content item whose type is + # function_approval_request (or equivalent). + approval_found = False + for message in response.messages: + for content in message.contents: + if getattr(content, "type", None) == "function_approval_request": + approval_found = True + break + assert approval_found, "expected FunctionApprovalRequestContent in agent response" + + async def test_resume_via_agent_responses_kwarg(self): + @workflow + async def wf(x: str, ctx: RunContext) -> str: + answer = await ctx.request_info(x, response_type=str, request_id="rid-1") + return f"got:{answer}" + + agent = wf.as_agent() + # First phase: suspend + await agent.run("topic") + # Second phase: resume via the agent surface + response = await agent.run(responses={"rid-1": "answered"}) + # Agent's final response should contain the workflow's text output. + text_blobs: list[str] = [] + for message in response.messages: + for content in message.contents: + text = getattr(content, "text", None) + if text: + text_blobs.append(text) + assert any("got:answered" in t for t in text_blobs) + + +class TestRunDocstringAllowsResponsesAndCheckpoint: + """Regression for bug_010: docstring must permit responses+checkpoint_id combo.""" + + def test_docstring_says_at_least_one(self): + doc = FunctionalWorkflow.run.__doc__ or "" + assert "At least one" in doc or "at least one" in doc + assert "Exactly one" not in doc + + +class TestFunctionalWorkflowExperimentalStage: + """Tests for the experimental stage annotations applied to functional workflow APIs.""" + + def test_public_symbols_are_marked_experimental(self) -> None: + symbols = [ + get_run_context, + RunContext, + StepWrapper, + step, + FunctionalWorkflow, + workflow, + FunctionalWorkflowAgent, + ] + + for symbol in symbols: + assert symbol.__feature_stage__ == "experimental" + assert symbol.__feature_id__ 
== ExperimentalFeature.FUNCTIONAL_WORKFLOWS.value + assert symbol.__doc__ is not None + assert ".. warning:: Experimental" in symbol.__doc__ diff --git a/python/samples/01-get-started/05_functional_workflow_with_agents.py b/python/samples/01-get-started/05_functional_workflow_with_agents.py new file mode 100644 index 0000000000..05e5f2637f --- /dev/null +++ b/python/samples/01-get-started/05_functional_workflow_with_agents.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Functional Workflow with Agents — Call agents inside @workflow + +This sample shows how to call agents inside a functional workflow. +Agent calls are just regular async function calls — no special wrappers needed. +""" + +import asyncio + +from agent_framework import Agent, workflow +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential + +# +client = FoundryChatClient(credential=AzureCliCredential()) + +writer = Agent( + name="WriterAgent", + instructions="Write a short poem (4 lines max) about the given topic.", + client=client, +) + +reviewer = Agent( + name="ReviewerAgent", + instructions="Review the given poem in one sentence. 
Is it good?", + client=client, +) +# + + +# +@workflow +async def poem_workflow(topic: str) -> str: + """Write a poem, then review it.""" + poem = (await writer.run(f"Write a poem about: {topic}")).text + review = (await reviewer.run(f"Review this poem: {poem}")).text + return f"Poem:\n{poem}\n\nReview: {review}" +# + + +async def main() -> None: + result = await poem_workflow.run("a cat learning to code") + print(result.get_outputs()[0]) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/01-get-started/06_functional_workflow_basics.py b/python/samples/01-get-started/06_functional_workflow_basics.py new file mode 100644 index 0000000000..8033a7ac4e --- /dev/null +++ b/python/samples/01-get-started/06_functional_workflow_basics.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Functional Workflow Basics — Orchestrate async functions with @workflow + +The functional API lets you write workflows as plain Python async functions. +No graph concepts, no edges, no executor classes — just call functions +and use native control flow (if/else, loops, asyncio.gather). + +This sample builds a minimal pipeline with two steps: +1. Convert text to uppercase +2. Reverse the text + +No external services are required. 
+""" + +import asyncio + +from agent_framework import workflow + + +# Plain async functions — no decorators needed +async def to_upper_case(text: str) -> str: + """Convert input to uppercase.""" + return text.upper() + + +async def reverse_text(text: str) -> str: + """Reverse the string.""" + return text[::-1] + + +# +@workflow +async def text_workflow(text: str) -> str: + """Uppercase the text, then reverse it.""" + upper = await to_upper_case(text) + return await reverse_text(upper) +# + + +async def main() -> None: + # + result = await text_workflow.run("hello world") + print(f"Output: {result.get_outputs()}") + print(f"Final state: {result.get_final_state()}") + # + + """ + Expected output: + Output: ['DLROW OLLEH'] + Final state: WorkflowRunState.IDLE + """ + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/01-get-started/05_first_workflow.py b/python/samples/01-get-started/07_first_graph_workflow.py similarity index 87% rename from python/samples/01-get-started/05_first_workflow.py rename to python/samples/01-get-started/07_first_graph_workflow.py index 74720e529f..2a84fabfab 100644 --- a/python/samples/01-get-started/05_first_workflow.py +++ b/python/samples/01-get-started/07_first_graph_workflow.py @@ -12,9 +12,12 @@ from typing_extensions import Never """ -First Workflow — Chain executors with edges +First Graph Workflow — Chain executors with edges -This sample builds a minimal workflow with two steps: +The graph API gives you full control over execution topology: edges, +fan-out/fan-in, switch/case, and superstep-based checkpointing. + +This sample builds a minimal graph workflow with two steps: 1. Convert text to uppercase (class-based executor) 2. 
Reverse the text (function-based executor) diff --git a/python/samples/01-get-started/06_host_your_agent.py b/python/samples/01-get-started/08_host_your_agent.py similarity index 100% rename from python/samples/01-get-started/06_host_your_agent.py rename to python/samples/01-get-started/08_host_your_agent.py diff --git a/python/samples/01-get-started/README.md b/python/samples/01-get-started/README.md index 9ecfdf08c9..61aefc6592 100644 --- a/python/samples/01-get-started/README.md +++ b/python/samples/01-get-started/README.md @@ -24,8 +24,10 @@ export FOUNDRY_MODEL="gpt-4o" # optional, defaults to gpt-4o | 2 | [02_add_tools.py](02_add_tools.py) | Define a function tool with `@tool` and attach it to an agent. | | 3 | [03_multi_turn.py](03_multi_turn.py) | Keep conversation history across turns with `AgentSession`. | | 4 | [04_memory.py](04_memory.py) | Add dynamic context with a custom `ContextProvider`. | -| 5 | [05_first_workflow.py](05_first_workflow.py) | Chain executors into a workflow with edges. | -| 6 | [06_host_your_agent.py](06_host_your_agent.py) | Host a single agent with Azure Functions. | +| 5 | [05_functional_workflow_with_agents.py](05_functional_workflow_with_agents.py) | Call agents inside a functional workflow. | +| 6 | [06_functional_workflow_basics.py](06_functional_workflow_basics.py) | Write a workflow as a plain async function. | +| 7 | [07_first_graph_workflow.py](07_first_graph_workflow.py) | Chain executors into a graph workflow with edges. | +| 8 | [08_host_your_agent.py](08_host_your_agent.py) | Host a single agent with Azure Functions. | Run any sample with: diff --git a/python/samples/03-workflows/README.md b/python/samples/03-workflows/README.md index ae3292a07a..15d224563f 100644 --- a/python/samples/03-workflows/README.md +++ b/python/samples/03-workflows/README.md @@ -30,6 +30,20 @@ Once comfortable with these, explore the rest of the samples below. 
## Samples Overview (by directory) +### functional + +Write workflows as plain Python async functions — no graph concepts, no executor classes, no edges. Use native control flow (`if`/`else`, loops, `asyncio.gather`) for branching and parallelism. + +| Sample | File | Concepts | +|---|---|---| +| Basic Pipeline | [functional/basic_pipeline.py](./functional/basic_pipeline.py) | Sequential steps as plain async functions | +| Basic Streaming Pipeline | [functional/basic_streaming_pipeline.py](./functional/basic_streaming_pipeline.py) | Stream workflow events in real time with `run(stream=True)` | +| Parallel Pipeline | [functional/parallel_pipeline.py](./functional/parallel_pipeline.py) | Fan-out/fan-in with `asyncio.gather` | +| Steps and Checkpointing | [functional/steps_and_checkpointing.py](./functional/steps_and_checkpointing.py) | `@step` decorator for per-step checkpointing and observability | +| Human-in-the-Loop Review | [functional/hitl_review.py](./functional/hitl_review.py) | HITL with `ctx.request_info()` and replay | +| Agent Integration | [functional/agent_integration.py](./functional/agent_integration.py) | Calling agents inside workflow steps | +| Naive Group Chat | [functional/naive_group_chat.py](./functional/naive_group_chat.py) | Simple round-robin group chat as a plain loop | + ### agents | Sample | File | Concepts | diff --git a/python/samples/03-workflows/functional/agent_integration.py b/python/samples/03-workflows/functional/agent_integration.py new file mode 100644 index 0000000000..b5911cb690 --- /dev/null +++ b/python/samples/03-workflows/functional/agent_integration.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Calling agents inside functional workflows. + +Agent calls work inside @workflow as plain function calls — no decorator needed. +Just call the agent and use the result. + +If you want per-step caching (so agent calls don't re-execute on HITL resume +or crash recovery), add @step. 
Since each agent call hits an LLM API (time + +money), @step is often worth it. But it's always opt-in. + +This sample shows both approaches side-by-side so you can see the difference. +""" + +import asyncio + +from agent_framework import Agent, step, workflow +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential + +# --------------------------------------------------------------------------- +# Create agents +# --------------------------------------------------------------------------- + +client = FoundryChatClient(credential=AzureCliCredential()) + +classifier_agent = Agent( + name="ClassifierAgent", + instructions=( + "Classify documents into one category: Technical, Legal, Marketing, or Scientific. " + "Reply with only the category name." + ), + client=client, +) + +writer_agent = Agent( + name="WriterAgent", + instructions="Summarize the given content in one sentence.", + client=client, +) + +reviewer_agent = Agent( + name="ReviewerAgent", + instructions="Review the given summary in one sentence. Is it accurate and complete?", + client=client, +) + +# --------------------------------------------------------------------------- +# Simplest approach: call agents directly inside the workflow. +# No @step, no wrappers — just plain function calls. +# --------------------------------------------------------------------------- + + +@workflow +async def simple_pipeline(document: str) -> str: + """Process a document — agents called inline, no @step.""" + classification = (await classifier_agent.run(f"Classify this document: {document}")).text + summary = (await writer_agent.run(f"Summarize: {document}")).text + review = (await reviewer_agent.run(f"Review this summary: {summary}")).text + + return f"Classification: {classification}\nSummary: {summary}\nReview: {review}" + + +# --------------------------------------------------------------------------- +# With @step: agent results are cached. 
On HITL resume or checkpoint +# recovery, completed steps return their saved result instead of calling +# the LLM again. Worth it for expensive operations. +# --------------------------------------------------------------------------- + + +@step +async def classify_document(doc: str) -> str: + return (await classifier_agent.run(f"Classify this document: {doc}")).text + + +@step +async def generate_summary(doc: str) -> str: + return (await writer_agent.run(f"Summarize: {doc}")).text + + +@step +async def review_summary(summary: str) -> str: + return (await reviewer_agent.run(f"Review this summary: {summary}")).text + + +@workflow +async def cached_pipeline(document: str) -> str: + """Same pipeline, but @step caches each agent call.""" + classification = await classify_document(document) + summary = await generate_summary(document) + review = await review_summary(summary) + + return f"Classification: {classification}\nSummary: {summary}\nReview: {review}" + + +async def main(): + # Simple version — agents called inline + result = await simple_pipeline.run("This is a technical document about machine learning...") + print(result.get_outputs()[0]) + + # Cached version — same result, but steps won't re-execute on resume + result = await cached_pipeline.run("This is a technical document about machine learning...") + print(f"\nCached: {result.get_outputs()[0]}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/basic_pipeline.py b/python/samples/03-workflows/functional/basic_pipeline.py new file mode 100644 index 0000000000..81514da53a --- /dev/null +++ b/python/samples/03-workflows/functional/basic_pipeline.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Basic sequential pipeline using the functional workflow API. + +The simplest possible workflow: plain async functions orchestrated by @workflow. +No @step decorator needed — just write Python. 
+""" + +import asyncio + +from agent_framework import workflow + + +# These are plain async functions — no decorators needed. +# They run normally inside the workflow, just like any other Python function. +async def fetch_data(url: str) -> dict[str, str | int]: + """Simulate fetching data from a URL.""" + return {"url": url, "content": f"Data from {url}", "status": 200} + + +async def transform_data(data: dict[str, str | int]) -> str: + """Transform raw data into a summary string.""" + return f"[{data['status']}] {data['content']}" + + +# @workflow turns this async function into a FunctionalWorkflow object. +# Without it, this is just a normal async function. With it, you get: +# - .run() that returns a WorkflowRunResult with events and outputs +# - .run(stream=True) for streaming events in real time +# - .as_agent() to use this workflow anywhere an agent is expected +# +# The function's first parameter receives the input from .run("..."). +# Add a `ctx: RunContext` parameter only if you need HITL, state, or custom events. +@workflow +async def data_pipeline(url: str) -> str: + """A simple sequential data pipeline.""" + raw = await fetch_data(url) + summary = await transform_data(raw) + + # This is just a function — plain Python works between calls. + # No need to wrap every operation in a separate async function. + is_valid = len(summary) > 0 and "[200]" in summary + tag = "VALID" if is_valid else "INVALID" + + # Returning a value automatically emits it as an output. + # Callers retrieve it via result.get_outputs(). 
+ return f"[{tag}] {summary}" + + +async def main(): + # .run() is provided by @workflow — a plain async function wouldn't have it + result = await data_pipeline.run("https://example.com/api/data") + print("Output:", result.get_outputs()[0]) + print("State:", result.get_final_state()) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/basic_streaming_pipeline.py b/python/samples/03-workflows/functional/basic_streaming_pipeline.py new file mode 100644 index 0000000000..4ee61da60f --- /dev/null +++ b/python/samples/03-workflows/functional/basic_streaming_pipeline.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Basic streaming pipeline using the functional workflow API. + +Stream workflow events in real time with run(stream=True). +""" + +import asyncio + +from agent_framework import workflow + + +# Plain async functions — no decorators needed for simple helpers. +async def fetch_data(url: str) -> dict[str, str | int]: + """Simulate fetching data from a URL.""" + return {"url": url, "content": f"Data from {url}", "status": 200} + + +async def transform_data(data: dict[str, str | int]) -> str: + """Transform raw data into a summary string.""" + return f"[{data['status']}] {data['content']}" + + +async def validate_result(summary: str) -> bool: + """Validate the transformed result.""" + return len(summary) > 0 and "[200]" in summary + + +# @workflow enables .run(stream=True), which returns a ResponseStream +# you can iterate over with `async for`. Without @workflow, you'd just +# have a normal async function with no streaming capability. 
+@workflow +async def data_pipeline(url: str) -> str: + """A simple sequential data pipeline.""" + raw = await fetch_data(url) + summary = await transform_data(raw) + is_valid = await validate_result(summary) + + return f"{summary} (valid={is_valid})" + + +async def main(): + # run(stream=True) returns a ResponseStream that yields events as they + # are produced. The raw stream includes lifecycle events (started, status) + # alongside application events — filter by event.type to find what you need. + stream = data_pipeline.run("https://example.com/api/data", stream=True) + async for event in stream: + if event.type == "output": + print(f"Output: {event.data}") + + # After iteration, get_final_response() returns the WorkflowRunResult + result = await stream.get_final_response() + print(f"Final state: {result.get_final_state()}") + + """ + Expected output: + Output: [200] Data from https://example.com/api/data (valid=True) + Final state: WorkflowRunState.IDLE + """ + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/hitl_review.py b/python/samples/03-workflows/functional/hitl_review.py new file mode 100644 index 0000000000..39f2dae885 --- /dev/null +++ b/python/samples/03-workflows/functional/hitl_review.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Human-in-the-loop review pipeline using functional workflows. + +Demonstrates ctx.request_info() for pausing the workflow to wait for +external input and resuming with run(responses={...}). + +HITL works with or without @step. The difference is what happens on resume: +- Without @step: every function re-executes from the top (fine for cheap calls). +- With @step: completed functions return their saved result instantly. + +This sample uses @step on write_draft() because it simulates an expensive +operation that shouldn't re-run just because the workflow was paused. 
+""" + +import asyncio + +from agent_framework import RunContext, WorkflowRunState, step, workflow + + +# @step saves the result. When the workflow resumes after the HITL pause, +# this returns its saved result instead of running the expensive operation again. +# +# In a real workflow you might call an agent here instead: +# @step +# async def write_draft(topic: str) -> str: +# return (await writer_agent.run(f"Write a draft about: {topic}")).text +@step +async def write_draft(topic: str) -> str: + """Simulate writing a draft — expensive, shouldn't re-run on resume.""" + print(f" write_draft executing for '{topic}'") + return f"Draft document about '{topic}': Lorem ipsum dolor sit amet..." + + +@step +async def revise_draft(draft: str, feedback: str) -> str: + """Revise the draft based on feedback.""" + return f"Revised: {draft[:50]}... [Applied feedback: {feedback}]" + + +@workflow +async def review_pipeline(topic: str, ctx: RunContext) -> str: + """Write a draft, get human review, then revise.""" + draft = await write_draft(topic) + + # ctx.request_info() suspends the workflow here. The caller gets back + # a WorkflowRunResult with state IDLE_WITH_PENDING_REQUESTS and can + # inspect the pending request via result.get_request_info_events(). + feedback = await ctx.request_info( + {"draft": draft, "instructions": "Please review this draft"}, + response_type=str, + request_id="review_request", + ) + + # This only executes after the caller resumes with run(responses={...}). + # write_draft above returns its saved result (thanks to @step), + # request_info returns the provided response, and we continue here. + return await revise_draft(draft, feedback) + + +async def main(): + # Phase 1: Run until the workflow pauses for human input + print("=== Phase 1: Initial run ===") + result1 = await review_pipeline.run("AI Safety") + + # If request_info() was reached, the state is IDLE_WITH_PENDING_REQUESTS. 
+ # If the workflow completed without hitting request_info(), it would be IDLE. + print(f"State: {(final_state := result1.get_final_state())}") + assert final_state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + requests = result1.get_request_info_events() + print(f"Pending request: {requests[0].request_id}") + + # Phase 2: Resume with the human's response + print("\n=== Phase 2: Resume with feedback ===") + print("(write_draft should NOT execute again — saved by @step)") + result2 = await review_pipeline.run(responses={"review_request": "Add more details about alignment research"}) + + print(f"State: {result2.get_final_state()}") + print(f"Output: {result2.get_outputs()[0]}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/naive_group_chat.py b/python/samples/03-workflows/functional/naive_group_chat.py new file mode 100644 index 0000000000..45a3266049 --- /dev/null +++ b/python/samples/03-workflows/functional/naive_group_chat.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Naive group chat using the functional workflow API. + +A simple round-robin group chat where agents take turns responding. +Because it's just a function, you control the loop, the turn order, +and the termination condition with plain Python — no framework abstractions. + +Compare this with the graph-based GroupChat orchestration to see how the +functional API lets you start simple and add complexity only when needed. 
+""" + +import asyncio + +from agent_framework import Agent, Message, workflow +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential + +# --------------------------------------------------------------------------- +# Create agents +# --------------------------------------------------------------------------- + +client = FoundryChatClient(credential=AzureCliCredential()) + +expert = Agent( + name="PythonExpert", + instructions=( + "You are a Python expert in a group discussion. " + "Answer questions about Python and refine your answer based on feedback. " + "Keep responses concise (2-3 sentences)." + ), + client=client, +) + +critic = Agent( + name="Critic", + instructions=( + "You are a constructive critic in a group discussion. " + "Point out edge cases, gotchas, or missing nuances in the previous answer. " + "If the answer is solid, say so briefly." + ), + client=client, +) + +summarizer = Agent( + name="Summarizer", + instructions=( + "You are a summarizer in a group discussion. " + "After the discussion, provide a final concise summary that incorporates " + "the expert's answer and the critic's feedback. Keep it to 2-3 sentences." + ), + client=client, +) + +# --------------------------------------------------------------------------- +# A naive group chat is just a loop — no special framework needed +# --------------------------------------------------------------------------- + + +@workflow +async def group_chat(question: str) -> str: + """Round-robin group chat: expert answers, critic reviews, summarizer wraps up.""" + participants = [expert, critic, summarizer] + # Passing list[Message] keeps roles/authorship intact between turns, + # instead of stringifying everything into a single prompt. 
+ conversation: list[Message] = [Message("user", [question])] + + # Simple round-robin: each agent sees the full conversation so far + for agent in participants: + response = await agent.run(conversation) + conversation.extend(response.messages) + + return "\n\n".join(f"{m.author_name or m.role}: {m.text}" for m in conversation) + + +async def main(): + result = await group_chat.run("What's the difference between a list and a tuple in Python?") + print(result.get_outputs()[0]) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/parallel_pipeline.py b/python/samples/03-workflows/functional/parallel_pipeline.py new file mode 100644 index 0000000000..21fd8dceac --- /dev/null +++ b/python/samples/03-workflows/functional/parallel_pipeline.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Parallel pipeline using asyncio.gather with functional workflows. + +Fan-out/fan-in uses native Python concurrency via asyncio.gather. +No @step needed — still just plain async functions. +""" + +import asyncio + +from agent_framework import workflow + + +# Plain async functions — asyncio.gather handles the concurrency, +# no framework primitives needed for parallelism. 
+async def research_web(topic: str) -> str: + """Simulate web research.""" + await asyncio.sleep(0.05) + return f"Web results for '{topic}': 10 articles found" + + +async def research_papers(topic: str) -> str: + """Simulate academic paper search.""" + await asyncio.sleep(0.05) + return f"Papers on '{topic}': 3 relevant papers" + + +async def research_news(topic: str) -> str: + """Simulate news search.""" + await asyncio.sleep(0.05) + return f"News about '{topic}': 5 recent articles" + + +async def synthesize(sources: list[str]) -> str: + """Combine research results into a summary.""" + return "Research Summary:\n" + "\n".join(f" - {s}" for s in sources) + + +# @workflow wraps the orchestration logic so you get .run(), streaming, +# and events. The functions it calls are plain Python — no decorators +# needed just because they're inside a workflow. +@workflow +async def research_pipeline(topic: str) -> str: + """Fan-out to three research tasks, then synthesize results.""" + # asyncio.gather runs all three concurrently — this is standard Python, + # not a framework concept. Use it the same way you would anywhere else. + # + # Tip: if any of these were wrapped with @step (e.g. an expensive agent call), + # the pattern is identical — @step composes with asyncio.gather, so each + # branch is independently cached on HITL resume or checkpoint restore. 
+ web, papers, news = await asyncio.gather( + research_web(topic), + research_papers(topic), + research_news(topic), + ) + + return await synthesize([web, papers, news]) + + +async def main(): + result = await research_pipeline.run("AI agents") + print(result.get_outputs()[0]) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/steps_and_checkpointing.py b/python/samples/03-workflows/functional/steps_and_checkpointing.py new file mode 100644 index 0000000000..93a9423df0 --- /dev/null +++ b/python/samples/03-workflows/functional/steps_and_checkpointing.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Introducing @step: per-step checkpointing and observability. + +The previous samples used plain functions — and that works. Workflows support +HITL (ctx.request_info) and checkpointing regardless of whether you use @step. + +The difference: without @step, a resumed workflow re-executes every function +call from the top. That's fine for cheap functions. But for expensive operations +(API calls, agent runs, etc.) you don't want to pay that cost again. + +@step saves each function's result so it skips re-execution on resume: +- On HITL resume, completed steps return their saved result instantly. +- On crash recovery from a checkpoint, earlier step results are restored. +- Each step emits executor_invoked/executor_completed events for observability. + +@step is opt-in. Plain functions still work alongside @step in the same workflow. +""" + +import asyncio + +from agent_framework import InMemoryCheckpointStorage, step, workflow + +# Track call counts to show which functions actually execute on resume +fetch_calls = 0 +transform_calls = 0 + + +# @step saves this function's result. On resume, it returns the saved +# result instead of re-executing — useful because this is expensive. 
+@step +async def fetch_data(url: str) -> dict[str, str | int]: + """Expensive operation — @step prevents re-execution on resume.""" + global fetch_calls + fetch_calls += 1 + print(f" fetch_data called (call #{fetch_calls})") + return {"url": url, "content": f"Data from {url}", "status": 200} + + +@step +async def transform_data(data: dict[str, str | int]) -> str: + """Another expensive operation — @step saves the result.""" + global transform_calls + transform_calls += 1 + print(f" transform_data called (call #{transform_calls})") + return f"[{data['status']}] {data['content']}" + + +# No @step — this is cheap, so it just re-runs on resume. That's fine. +async def validate_result(summary: str) -> bool: + """Cheap validation — no @step needed.""" + return len(summary) > 0 and "[200]" in summary + + +storage = InMemoryCheckpointStorage() + + +# checkpoint_storage tells @workflow where to persist step results. +# Each @step saves a checkpoint after it completes. +@workflow(checkpoint_storage=storage) +async def data_pipeline(url: str) -> str: + """Mix of @step functions and plain functions.""" + raw = await fetch_data(url) + summary = await transform_data(raw) + is_valid = await validate_result(summary) + + return f"{summary} (valid={is_valid})" + + +async def main(): + # --- Run 1: Everything executes normally --- + print("=== Run 1: Fresh execution ===") + result = await data_pipeline.run("https://example.com/api/data") + print(f"Output: {result.get_outputs()[0]}") + print(f"fetch_calls={fetch_calls}, transform_calls={transform_calls}") + + # @step functions emit executor events; plain functions don't. + print("\nEvents:") + for event in result: + if event.type in ("executor_invoked", "executor_completed"): + print(f" {event.type}: {event.executor_id}") + + # --- Run 2: Restore from checkpoint --- + # The workflow re-executes, but @step functions return saved results. + # Only validate_result() (no @step) actually runs again. 
+ print("\n=== Run 2: Restored from checkpoint ===") + latest = await storage.get_latest(workflow_name="data_pipeline") + assert latest is not None + + result2 = await data_pipeline.run(checkpoint_id=latest.checkpoint_id) + print(f"Output: {result2.get_outputs()[0]}") + print(f"fetch_calls={fetch_calls}, transform_calls={transform_calls}") + print("(call counts unchanged — @step results were restored from checkpoint)") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/README.md b/python/samples/README.md index 953d9ff9bb..0ff8563933 100644 --- a/python/samples/README.md +++ b/python/samples/README.md @@ -20,8 +20,10 @@ Start with `01-get-started/` and work through the numbered files: 2. **[02_add_tools.py](./01-get-started/02_add_tools.py)** — Add function tools with `@tool` 3. **[03_multi_turn.py](./01-get-started/03_multi_turn.py)** — Multi-turn conversations with `AgentSession` 4. **[04_memory.py](./01-get-started/04_memory.py)** — Agent memory with `ContextProvider` -5. **[05_first_workflow.py](./01-get-started/05_first_workflow.py)** — Build a workflow with executors and edges -6. **[06_host_your_agent.py](./01-get-started/06_host_your_agent.py)** — Host your agent via Azure Functions +5. **[05_functional_workflow_with_agents.py](./01-get-started/05_functional_workflow_with_agents.py)** — Call agents inside a functional workflow +6. **[06_functional_workflow_basics.py](./01-get-started/06_functional_workflow_basics.py)** — Write a workflow as a plain async function +7. **[07_first_graph_workflow.py](./01-get-started/07_first_graph_workflow.py)** — Build a workflow with executors and edges +8. **[08_host_your_agent.py](./01-get-started/08_host_your_agent.py)** — Host your agent via Azure Functions ## Prerequisites