From 0f605d7f0d03a632b37ba1956299bb90806170a0 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 31 Mar 2026 12:14:13 -0700 Subject: [PATCH 1/9] Refactor workflows kwargs usage --- .../_workflows/_agent_executor.py | 112 ++++++++--------- .../core/agent_framework/_workflows/_const.py | 4 + .../agent_framework/_workflows/_workflow.py | 118 +++++++++++++++--- .../state-management/workflow_kwargs.py | 2 +- 4 files changed, 161 insertions(+), 75 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 462c3f8c64..ba5be916b0 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -2,7 +2,7 @@ import logging import sys -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import Awaitable, Callable from dataclasses import dataclass from typing import Any, Literal, cast @@ -14,7 +14,7 @@ from .._sessions import AgentSession from .._types import AgentResponse, AgentResponseUpdate, Message from ._agent_utils import resolve_agent_id -from ._const import WORKFLOW_RUN_KWARGS_KEY +from ._const import GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY from ._executor import Executor, handler from ._message_utils import normalize_messages_input from ._request_info_mixin import response_handler @@ -350,14 +350,16 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR Returns: The complete AgentResponse, or None if waiting for user input. """ - run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})) - + function_invocation_kwargs, client_kwargs, backward_compatible_kwargs = self._prepare_agent_run_args( + ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) + ) response = await self._agent.run( self._cache, stream=False, session=self._session, - options=options, - **run_kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + **backward_compatible_kwargs, ) await ctx.yield_output(response) @@ -379,7 +381,9 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp Returns: The complete AgentResponse, or None if waiting for user input. """ - run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})) + function_invocation_kwargs, client_kwargs, backward_compatible_kwargs = self._prepare_agent_run_args( + ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) + ) updates: list[AgentResponseUpdate] = [] streamed_user_input_requests: list[Content] = [] @@ -387,8 +391,9 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp self._cache, stream=True, session=self._session, - options=options, - **run_kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + **backward_compatible_kwargs, ) async for update in stream: updates.append(update) @@ -438,21 +443,35 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp # and must not appear in **run_kwargs to avoid TypeError from duplicate values. _RESERVED_RUN_PARAMS: frozenset[str] = frozenset({"session", "stream", "messages"}) - @staticmethod - def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]: - """Prepare kwargs and options for agent.run(), avoiding duplicate option passing. + def _prepare_agent_run_args( + self, + raw_run_kwargs: dict[str, Any], + ) -> tuple[dict[str, Any] | None, dict[str, Any] | None, dict[str, Any]]: + """Prepare function_invocation_kwargs, client_kwargs, and backward-compatible kwargs for agent.run(). + + Extracts ``function_invocation_kwargs`` and ``client_invocation_kwargs`` from the + workflow state dict, resolving per-executor entries using ``self.id``. The + ``__global__`` sentinel key (set by ``Workflow._resolve_invocation_kwargs``) denotes + global kwargs that apply to all executors. Per-executor dicts use executor IDs as + keys; this executor extracts only its own entry. - Workflow-level kwargs are propagated to tool calls through - `options.additional_function_arguments`. If workflow kwargs include an - `options` key, merge it into the final options object and remove it from - kwargs before spreading `**run_kwargs`. + Any remaining keys in the raw dict are treated as backward-compatible agent.run() + kwargs. Reserved parameters (session, stream, messages) that are explicitly managed + by AgentExecutor are stripped to prevent duplicate-keyword collisions. - Reserved parameters (session, stream, messages) that are explicitly - managed by AgentExecutor are stripped from run_kwargs to prevent - ``TypeError: got multiple values for keyword argument`` collisions. + Returns: + A 3-tuple of (function_invocation_kwargs, client_kwargs, backward_compatible_kwargs). """ run_kwargs = dict(raw_run_kwargs) + # Extract the already-resolved invocation kwargs dicts + # (set by Workflow._resolve_invocation_kwargs). + fi_resolved = run_kwargs.pop("function_invocation_kwargs", None) + ci_resolved = run_kwargs.pop("client_invocation_kwargs", None) + + function_invocation_kwargs = self._resolve_executor_kwargs(fi_resolved) + client_kwargs = self._resolve_executor_kwargs(ci_resolved) + # Strip reserved params that AgentExecutor passes explicitly to agent.run(). for key in AgentExecutor._RESERVED_RUN_PARAMS: if key in run_kwargs: @@ -463,45 +482,22 @@ def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, A ) run_kwargs.pop(key) - options_from_workflow = run_kwargs.pop("options", None) - workflow_additional_args = run_kwargs.pop("additional_function_arguments", None) - - options: dict[str, Any] = {} - if options_from_workflow is not None: - if isinstance(options_from_workflow, Mapping): - options_from_workflow_map = cast(Mapping[str, Any], options_from_workflow) - for key, value in options_from_workflow_map.items(): - options[key] = value - else: - logger.warning( - "Ignoring non-mapping workflow 'options' kwarg of type %s for AgentExecutor %s.", - type(options_from_workflow).__name__, - AgentExecutor.__name__, - ) - - existing_additional_args = options.get("additional_function_arguments") - additional_args: dict[str, Any] - if isinstance(existing_additional_args, Mapping): - existing_additional_args_map = cast(Mapping[str, Any], existing_additional_args) - additional_args = {key: value for key, value in existing_additional_args_map.items()} - else: - additional_args = {} - - if workflow_additional_args is not None: - if isinstance(workflow_additional_args, Mapping): - workflow_additional_args_map = cast(Mapping[str, Any], workflow_additional_args) - additional_args.update({key: value for key, value in workflow_additional_args_map.items()}) - else: - logger.warning( - "Ignoring non-mapping workflow 'additional_function_arguments' kwarg of type %s for AgentExecutor %s.", # noqa: E501 - type(workflow_additional_args).__name__, - AgentExecutor.__name__, - ) + return function_invocation_kwargs, client_kwargs, run_kwargs - if run_kwargs: - additional_args.update(run_kwargs) + def _resolve_executor_kwargs(self, resolved: dict[str, Any] | None) -> dict[str, Any] | None: + """Extract this executor's kwargs from a resolved invocation kwargs dict. - if additional_args: - options["additional_function_arguments"] = additional_args + Args: + resolved: The resolved dict produced by ``Workflow._resolve_invocation_kwargs``, + containing either a ``__global__`` key (global kwargs) or executor-ID keys + (per-executor kwargs). May also be ``None``. - return run_kwargs, options or None + Returns: + The kwargs for this executor, or ``None`` if not applicable. + """ + if not isinstance(resolved, dict): + return None + executor_kwargs = resolved.get(self.id) or resolved.get(GLOBAL_KWARGS_KEY) + if isinstance(executor_kwargs, dict): + return cast(dict[str, Any], executor_kwargs) or None + return None diff --git a/python/packages/core/agent_framework/_workflows/_const.py b/python/packages/core/agent_framework/_workflows/_const.py index a8416af790..e83025bbdc 100644 --- a/python/packages/core/agent_framework/_workflows/_const.py +++ b/python/packages/core/agent_framework/_workflows/_const.py @@ -14,6 +14,10 @@ # to pass kwargs from workflow.run() through to agent.run() and @tool functions. WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs" +# Sentinel key used in resolved invocation kwargs dicts to denote global kwargs +# that apply to all executors (as opposed to per-executor keyed entries). +GLOBAL_KWARGS_KEY = "__global__" + def INTERNAL_SOURCE_ID(executor_id: str) -> str: """Generate an internal source ID for a given executor.""" diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index cf030bf7b0..e862648402 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -10,14 +10,15 @@ import logging import types import uuid -from collections.abc import AsyncIterable, Awaitable, Callable, Sequence +import warnings +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from typing import Any, Literal, overload from .._types import ResponseStream from ..observability import OtelAttr, capture_exception, create_workflow_span from ._agent import WorkflowAgent from ._checkpoint import CheckpointStorage -from ._const import DEFAULT_MAX_ITERATIONS, WORKFLOW_RUN_KWARGS_KEY +from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY from ._edge import ( EdgeGroup, FanOutEdgeGroup, @@ -180,7 +181,6 @@ def __init__( description: str | None = None, max_iterations: int = DEFAULT_MAX_ITERATIONS, output_executors: list[str] | None = None, - **kwargs: Any, ): """Initialize the workflow with a list of edges. @@ -198,7 +198,6 @@ def __init__( WorkflowBuilder, this will be the description of the builder. output_executors: Optional list of executor IDs whose outputs will be considered workflow outputs. If None or empty, all executor outputs are treated as workflow outputs. - kwargs: Additional keyword arguments. Unused in this implementation. """ self.edge_groups = list(edge_groups) self.executors = dict(executors) @@ -300,7 +299,9 @@ async def _run_workflow_with_tracing( initial_executor_fn: Callable[[], Awaitable[None]] | None = None, reset_context: bool = True, streaming: bool = False, - run_kwargs: dict[str, Any] | None = None, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + run_kwargs: Mapping[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Private method to run workflow with proper tracing. @@ -311,7 +312,11 @@ async def _run_workflow_with_tracing( initial_executor_fn: Optional function to execute initial executor reset_context: Whether to reset the context for a new run streaming: Whether to enable streaming mode for agents - run_kwargs: Optional kwargs to store in State for agent invocations + function_invocation_kwargs: Optional kwargs to store in State for function + invocations in subagents + client_invocation_kwargs: Optional kwargs to store in State for chat client + invocations in subagents + run_kwargs: Deprecated optional kwargs to store in State for agent invocations Yields: WorkflowEvent: The events generated during the workflow execution. @@ -350,7 +355,20 @@ async def _run_workflow_with_tracing( # Only overwrite when new kwargs are explicitly provided or state was # just cleared (fresh run). On continuation (reset_context=False) with # no new kwargs, preserve the kwargs from the original run. - if run_kwargs is not None: + if function_invocation_kwargs or client_invocation_kwargs: + combined_kwargs: dict[str, Any] = {} + if function_invocation_kwargs: + combined_kwargs["function_invocation_kwargs"] = self._resolve_invocation_kwargs( + function_invocation_kwargs, "function_invocation_kwargs" + ) + if client_invocation_kwargs: + combined_kwargs["client_invocation_kwargs"] = self._resolve_invocation_kwargs( + client_invocation_kwargs, "client_invocation_kwargs" + ) + self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs) + elif run_kwargs is not None: + # Deprecated path for direct kwargs - still support but prefer the more explicit + # function_invocation_kwargs/client_invocation_kwargs self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs) elif reset_context: self._state.set(WORKFLOW_RUN_KWARGS_KEY, {}) @@ -459,9 +477,10 @@ def run( message: Any | None = None, *, stream: Literal[True], - responses: dict[str, Any] | None = None, + responses: Mapping[str, Any] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[WorkflowEvent, WorkflowRunResult]: ... @@ -471,10 +490,11 @@ def run( message: Any | None = None, *, stream: Literal[False] = ..., - responses: dict[str, Any] | None = None, + responses: Mapping[str, Any] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, + function_invocation_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[WorkflowRunResult]: ... @@ -483,10 +503,12 @@ def run( message: Any | None = None, *, stream: bool = False, - responses: dict[str, Any] | None = None, + responses: Mapping[str, Any] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[WorkflowEvent, WorkflowRunResult] | Awaitable[WorkflowRunResult]: """Run the workflow, optionally streaming events. @@ -509,7 +531,14 @@ def run( (restore then send responses). checkpoint_storage: Runtime checkpoint storage. include_status_events: Whether to include status events (non-streaming only). - **kwargs: Additional keyword arguments to pass through to agent invocations. + function_invocation_kwargs: Keyword arguments forwarded to tool invocations in + subagents. Either a mapping for agent name or agent executor id to kwargs, + or a flat mapping of kwargs for all tool invocations. + client_invocation_kwargs: Keyword arguments forwarded to chat client calls in + subagents. Either a mapping for agent name or agent executor id to kwargs, + or a flat mapping of kwargs for all chat client calls. + **kwargs: Deprecated additional keyword arguments for the subagents. They are + forwarded to both tool invocations and the chat clients. Returns: When stream=True: A ResponseStream[WorkflowEvent, WorkflowRunResult] for @@ -519,6 +548,19 @@ def run( Raises: ValueError: If parameter combination is invalid. """ + if kwargs: + warnings.warn( + "Passing runtime keyword arguments directly to run() is deprecated; pass tool values via " + "function_invocation_kwargs and client_invocation_kwargs instead.", + DeprecationWarning, + stacklevel=2, + ) + if function_invocation_kwargs or client_invocation_kwargs: + raise ValueError( + "Cannot provide both deprecated kwargs and function_invocation_kwargs/client_invocation_kwargs. " + "Please consolidate to function_invocation_kwargs/client_invocation_kwargs." + ) + # Validate parameters and set running flag eagerly (before any async work) self._validate_run_params(message, responses, checkpoint_id) self._ensure_not_running() @@ -530,6 +572,8 @@ def run( checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, streaming=stream, + function_invocation_kwargs=function_invocation_kwargs, + client_invocation_kwargs=client_invocation_kwargs, **kwargs, ), finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events), @@ -546,10 +590,12 @@ async def _run_core( self, message: Any | None = None, *, - responses: dict[str, Any] | None = None, + responses: Mapping[str, Any] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, streaming: bool = False, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[WorkflowEvent]: """Single core execution path for both streaming and non-streaming modes. @@ -569,6 +615,8 @@ async def _run_core( initial_executor_fn=initial_executor_fn, reset_context=reset_context, streaming=streaming, + function_invocation_kwargs=function_invocation_kwargs, + client_invocation_kwargs=client_invocation_kwargs, # Empty **kwargs (no caller-provided kwargs) is collapsed to None so that # continuation calls without explicit kwargs preserve the original run's kwargs. # A non-empty kwargs dict (even one with empty values like {"key": {}}) @@ -624,7 +672,7 @@ def _finalize_events( @staticmethod def _validate_run_params( message: Any | None, - responses: dict[str, Any] | None, + responses: Mapping[str, Any] | None, checkpoint_id: str | None, ) -> None: """Validate parameter combinations for run(). @@ -650,7 +698,7 @@ def _validate_run_params( def _resolve_execution_mode( self, message: Any | None, - responses: dict[str, Any] | None, + responses: Mapping[str, Any] | None, checkpoint_id: str | None, checkpoint_storage: CheckpointStorage | None, ) -> tuple[Callable[[], Awaitable[None]], bool]: @@ -680,7 +728,7 @@ async def _restore_and_send_responses( self, checkpoint_id: str, checkpoint_storage: CheckpointStorage | None, - responses: dict[str, Any], + responses: Mapping[str, Any], ) -> None: """Restore from a checkpoint then send responses to pending requests. @@ -700,7 +748,7 @@ async def _restore_and_send_responses( await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage) await self._send_responses_internal(responses) - async def _send_responses_internal(self, responses: dict[str, Any]) -> None: + async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None: """Internal method to validate and send responses to the executors.""" pending_requests = await self._runner_context.get_pending_request_info_events() if not pending_requests: @@ -739,6 +787,44 @@ def _get_executor_by_id(self, executor_id: str) -> Executor: raise ValueError(f"Executor with ID {executor_id} not found.") return self.executors[executor_id] + def _resolve_invocation_kwargs( + self, + kwargs: Mapping[str, Any], + param_name: str, + ) -> dict[str, Any]: + """Resolve invocation kwargs into a normalized per-executor or global format. + + Detects whether the provided kwargs dict uses per-executor targeting by checking + if any top-level key matches a known executor ID in the workflow. If at least one + key matches, all entries are treated as per-executor. Otherwise the dict is treated + as global kwargs that apply to every executor. + + Args: + kwargs: The raw invocation kwargs from the caller. + param_name: The parameter name (for logging), e.g. ``"function_invocation_kwargs"``. + + Returns: + A dict with either: + - ``{"__global__": }`` for global kwargs, or + - The original dict unchanged for per-executor kwargs. + """ + executor_ids = set(self.executors.keys()) + matched_ids = kwargs.keys() & executor_ids + if matched_ids: + logger.info( + "Detected per-executor %s: executor ID(s) %s found in keys. " + "All entries will be treated as per-executor.", + param_name, + matched_ids, + ) + return dict(kwargs) + + logger.info( + "No executor IDs found in %s keys; treating as global kwargs for all executors.", + param_name, + ) + return {GLOBAL_KWARGS_KEY: dict(kwargs)} + def _should_yield_output_event(self, event: WorkflowEvent[Any]) -> bool: """Determine if an output event should be yielded as a workflow output. diff --git a/python/samples/03-workflows/state-management/workflow_kwargs.py b/python/samples/03-workflows/state-management/workflow_kwargs.py index 630eaafc52..c8da0b0cf9 100644 --- a/python/samples/03-workflows/state-management/workflow_kwargs.py +++ b/python/samples/03-workflows/state-management/workflow_kwargs.py @@ -83,7 +83,7 @@ async def main() -> None: # Create chat client client = FoundryChatClient( project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], - model=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], + model=os.environ["FOUNDRY_MODEL"], credential=AzureCliCredential(), ) From a6e0811c921be7ac9546132bfbeb81eae9195eee Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 31 Mar 2026 12:53:22 -0700 Subject: [PATCH 2/9] Update sample --- .../samples/03-workflows/state-management/workflow_kwargs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/samples/03-workflows/state-management/workflow_kwargs.py b/python/samples/03-workflows/state-management/workflow_kwargs.py index c8da0b0cf9..4a9ad4ad9f 100644 --- a/python/samples/03-workflows/state-management/workflow_kwargs.py +++ b/python/samples/03-workflows/state-management/workflow_kwargs.py @@ -129,7 +129,7 @@ async def main() -> None: # Run workflow with kwargs - these will flow through to tools async for event in workflow.run( "Please get my user data and then call the users API endpoint.", - additional_function_arguments={"custom_data": custom_data, "user_token": user_token}, + function_invocation_kwargs={"custom_data": custom_data, "user_token": user_token}, stream=True, ): if event.type == "output": @@ -137,7 +137,7 @@ async def main() -> None: if isinstance(output_data, list): for item in output_data: if isinstance(item, Message) and item.text: - print(f"\n[Final Answer]: {item.text}") + print(f"\n[{item.author_name}]: {item.text}") print("\n" + "=" * 70) print("Sample Complete") From f95d6af2a2612df1cf4965bd8a686232b9c3c83f Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 31 Mar 2026 13:53:24 -0700 Subject: [PATCH 3/9] Add tests --- .../tests/workflow/test_agent_executor.py | 162 +++++++++- .../tests/workflow/test_workflow_kwargs.py | 304 +++++++++++++----- 2 files changed, 369 insertions(+), 97 deletions(-) diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 6298a8963d..d68b3fd4b5 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -338,44 +338,50 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() - @pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"]) async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None: """_prepare_agent_run_args must remove reserved kwargs and log a warning.""" + agent = _CountingAgent(id="test_agent", name="TestAgent") + executor = AgentExecutor(agent, id="test_exec") + raw: dict[str, Any] = { reserved_kwarg: "should-be-stripped", "custom_key": "keep-me", } with caplog.at_level(logging.WARNING): - run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + _, _, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] - assert reserved_kwarg not in run_kwargs - assert "custom_key" in run_kwargs - assert options is not None - assert options["additional_function_arguments"]["custom_key"] == "keep-me" + assert reserved_kwarg not in backward_kwargs + assert "custom_key" in backward_kwargs + assert backward_kwargs["custom_key"] == "keep-me" assert any(reserved_kwarg in record.message for record in caplog.records) async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None: """Non-reserved workflow kwargs should pass through unchanged.""" + agent = _CountingAgent(id="test_agent", name="TestAgent") + executor = AgentExecutor(agent, id="test_exec") + raw: dict[str, Any] = {"custom_param": "value", "another": 42} - run_kwargs, _options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] - assert run_kwargs["custom_param"] == "value" - assert run_kwargs["another"] == 42 + _, _, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert backward_kwargs["custom_param"] == "value" + assert backward_kwargs["another"] == 42 async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once( caplog: "LogCaptureFixture", ) -> None: """All reserved kwargs should be stripped when supplied together, each emitting a warning.""" + agent = _CountingAgent(id="test_agent", name="TestAgent") + executor = AgentExecutor(agent, id="test_exec") + raw: dict[str, Any] = {"session": "x", "stream": True, "messages": [], "custom": 1} with caplog.at_level(logging.WARNING): - run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + _, _, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] - assert "session" not in run_kwargs - assert "stream" not in run_kwargs - assert "messages" not in run_kwargs - assert run_kwargs["custom"] == 1 - assert options is not None - assert options["additional_function_arguments"]["custom"] == 1 + assert "session" not in backward_kwargs + assert "stream" not in backward_kwargs + assert "messages" not in backward_kwargs + assert backward_kwargs["custom"] == 1 warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()} assert warned_keys == {"session", "stream", "messages"} @@ -638,3 +644,129 @@ async def test_checkpoint_restore_works_without_context_mode_in_state() -> None: assert cache[0].text == "cached msg" # context_mode should remain as configured in the constructor, not changed by restore assert executor._context_mode == "last_agent" # pyright: ignore[reportPrivateUsage] + + +# --------------------------------------------------------------------------- +# Per-executor kwargs resolution tests +# --------------------------------------------------------------------------- + + +from agent_framework._workflows._const import GLOBAL_KWARGS_KEY + + +async def test_resolve_executor_kwargs_returns_global_kwargs() -> None: + """_resolve_executor_kwargs with __global__ key returns the global kwargs.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + resolved = {"__global__": {"tool_param": "value"}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result == {"tool_param": "value"} + + +async def test_resolve_executor_kwargs_returns_per_executor_kwargs() -> None: + """_resolve_executor_kwargs with matching executor ID returns that executor's kwargs.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + resolved = {"exec_a": {"my_param": 42}, "exec_b": {"other_param": 99}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result == {"my_param": 42} + + +async def test_resolve_executor_kwargs_returns_none_for_unmatched_per_executor() -> None: + """_resolve_executor_kwargs returns None when per-executor dict has no matching ID.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_c") + + resolved = {"exec_a": {"my_param": 42}, "exec_b": {"other_param": 99}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result is None + + +async def test_resolve_executor_kwargs_returns_none_for_none_input() -> None: + """_resolve_executor_kwargs returns None when input is None.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + result = executor._resolve_executor_kwargs(None) # pyright: ignore[reportPrivateUsage] + assert result is None + + +async def test_resolve_executor_kwargs_prefers_executor_id_over_global() -> None: + """_resolve_executor_kwargs prefers executor-specific entry over __global__.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + # Dict has both a per-executor entry and a global entry + resolved = {"exec_a": {"specific": True}, GLOBAL_KWARGS_KEY: {"global": True}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result == {"specific": True} + + +async def test_resolve_executor_kwargs_returns_none_for_empty_dict_value() -> None: + """_resolve_executor_kwargs returns None when the matched value is an empty dict.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + resolved = {GLOBAL_KWARGS_KEY: {}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result is None + + +async def test_prepare_agent_run_args_extracts_function_invocation_kwargs() -> None: + """_prepare_agent_run_args extracts function_invocation_kwargs from the state dict.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + raw: dict[str, Any] = { + "function_invocation_kwargs": {GLOBAL_KWARGS_KEY: {"tool_key": "tool_val"}}, + } + fi_kwargs, client_kwargs, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs == {"tool_key": "tool_val"} + assert client_kwargs is None + assert "function_invocation_kwargs" not in backward_kwargs + + +async def test_prepare_agent_run_args_extracts_client_invocation_kwargs() -> None: + """_prepare_agent_run_args extracts client_invocation_kwargs from the state dict.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + raw: dict[str, Any] = { + "client_invocation_kwargs": {GLOBAL_KWARGS_KEY: {"model": "gpt-4"}}, + } + fi_kwargs, client_kwargs, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs is None + assert client_kwargs == {"model": "gpt-4"} + assert "client_invocation_kwargs" not in backward_kwargs + + +async def test_prepare_agent_run_args_per_executor_resolution() -> None: + """_prepare_agent_run_args resolves per-executor function_invocation_kwargs using self.id.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + raw: dict[str, Any] = { + "function_invocation_kwargs": { + "exec_a": {"my_tool_key": "my_val"}, + "exec_b": {"other_tool_key": "other_val"}, + }, + } + fi_kwargs, _, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs == {"my_tool_key": "my_val"} + + +async def test_prepare_agent_run_args_per_executor_no_match() -> None: + """_prepare_agent_run_args returns None for function_invocation_kwargs when executor ID not found.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_c") + + raw: dict[str, Any] = { + "function_invocation_kwargs": { + "exec_a": {"my_tool_key": "my_val"}, + "exec_b": {"other_tool_key": "other_val"}, + }, + } + fi_kwargs, _, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs is None diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index d315f75f85..d14be92bca 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import AsyncIterable, Awaitable -from typing import Annotated, Any, Literal, overload +from typing import TYPE_CHECKING, Annotated, Any, Literal, overload import pytest @@ -26,6 +26,9 @@ SequentialBuilder, ) +if TYPE_CHECKING: + from _pytest.logging import LogCaptureFixture + # Track kwargs received by tools during test execution _received_kwargs: list[dict[str, Any]] = [] @@ -205,50 +208,8 @@ async def test_sequential_run_kwargs_flow() -> None: assert agent.captured_kwargs[0].get("custom_data") == {"test": True} -async def test_sequential_run_options_does_not_conflict_with_agent_options() -> None: - """Test workflow.run(options=...) does not conflict with Agent.run(options=...).""" - agent = _OptionsAwareAgent(name="options_agent") - workflow = SequentialBuilder(participants=[agent]).build() - - custom_data = {"session_id": "abc123"} - user_token = {"user_name": "alice"} - provided_options = { - "store": False, - "additional_function_arguments": {"source": "workflow-options"}, - } - - async for event in workflow.run( - "test message", - stream=True, - options=provided_options, - custom_data=custom_data, - user_token=user_token, - ): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - assert len(agent.captured_options) >= 1 - captured_options: dict[str, Any] | None = agent.captured_options[0] - assert captured_options is not None - assert captured_options.get("store") is False - - additional_args: Any = captured_options.get("additional_function_arguments") - assert isinstance(additional_args, dict) - assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType] - - # "options" should be passed once via the dedicated options parameter, - # not duplicated in **kwargs. - assert len(agent.captured_kwargs) >= 1 - captured_kwargs = agent.captured_kwargs[0] - assert "options" not in captured_kwargs - assert captured_kwargs.get("custom_data") == custom_data - assert captured_kwargs.get("user_token") == user_token - - async def test_sequential_run_additional_function_arguments_flattened() -> None: - """Test workflow.run(additional_function_arguments=...) maps directly to tool kwargs.""" + """Test workflow.run(additional_function_arguments=...) passes through as a backward-compat kwarg.""" agent = _OptionsAwareAgent(name="options_agent") workflow = SequentialBuilder(participants=[agent]).build() @@ -263,46 +224,14 @@ async def test_sequential_run_additional_function_arguments_flattened() -> None: if event.type == "status" and event.state == WorkflowRunState.IDLE: break - assert len(agent.captured_options) >= 1 - captured_options: dict[str, Any] | None = agent.captured_options[0] - assert captured_options is not None - - additional_args: Any = captured_options.get("additional_function_arguments") - assert isinstance(additional_args, dict) - assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType] - assert "additional_function_arguments" not in additional_args - + # additional_function_arguments is passed through as a backward-compat kwarg, + # not merged into options. assert len(agent.captured_kwargs) >= 1 captured_kwargs = agent.captured_kwargs[0] - assert "additional_function_arguments" not in captured_kwargs - - -async def test_sequential_run_additional_function_arguments_merges_with_options() -> None: - """Test workflow additional_function_arguments merges with workflow options.""" - agent = _OptionsAwareAgent(name="options_agent") - workflow = SequentialBuilder(participants=[agent]).build() - - async for event in workflow.run( - "test message", - stream=True, - options={"additional_function_arguments": {"source": "workflow-options"}}, - additional_function_arguments={"custom_data": {"session_id": "abc123"}}, - user_token={"user_name": "alice"}, - ): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - assert len(agent.captured_options) >= 1 - captured_options: dict[str, Any] | None = agent.captured_options[0] - assert captured_options is not None - - additional_args: Any = captured_options.get("additional_function_arguments") - assert isinstance(additional_args, dict) - assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("custom_data") == {"session_id": "abc123"} # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("user_token") == {"user_name": "alice"} # pyright: ignore[reportUnknownMemberType] - assert "additional_function_arguments" not in additional_args + assert captured_kwargs.get("additional_function_arguments") == { + "custom_data": custom_data, + "user_token": user_token, + } # endregion @@ -1179,3 +1108,214 @@ async def test_nested_subworkflow_kwargs_propagation() -> None: # endregion + + +# region Per-Executor Invocation Kwargs Tests + + +async def test_legacy_kwargs_cannot_coexist_with_new_invocation_kwargs() -> None: + """Passing both legacy **kwargs and function_invocation_kwargs/client_invocation_kwargs must raise ValueError.""" + agent = _KwargsCapturingAgent(name="agent1") + workflow = SequentialBuilder(participants=[agent]).build() + + with pytest.raises(ValueError, match="Cannot provide both deprecated kwargs"): + await workflow.run( + "test", + function_invocation_kwargs={"tool_key": "val"}, + custom_legacy_kwarg="should_conflict", + ) + + with pytest.raises(ValueError, match="Cannot provide both deprecated kwargs"): + await workflow.run( + "test", + client_invocation_kwargs={"model": "gpt-4"}, + custom_legacy_kwarg="should_conflict", + ) + + +async def test_function_and_client_invocation_kwargs_together() -> None: + """Both function_invocation_kwargs and client_invocation_kwargs can be provided in the same call.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + fi_kwargs = {"tool_param": "tool_value"} + ci_kwargs = {"temperature": 0.7} + + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs=fi_kwargs, + client_invocation_kwargs=ci_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + # Both agents should receive both kwargs + for agent in [agent1, agent2]: + assert len(agent.captured_kwargs) >= 1 + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert agent.captured_kwargs[0].get("client_kwargs") == ci_kwargs + + +async def test_global_function_invocation_kwargs_flow_to_all_agents() -> None: + """Global function_invocation_kwargs should be received by all agents in a sequential workflow.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + fi_kwargs = {"tool_param": "shared_value"} + + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs=fi_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + # Both agents should receive function_invocation_kwargs + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + + +async def test_per_executor_function_invocation_kwargs_routes_to_correct_agent() -> None: + """Per-executor function_invocation_kwargs should only be received by the targeted agent.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + # Per-executor: keys match agent names (which are used as executor IDs) + fi_kwargs = { + "agent1": {"tool_param": "value_for_agent1"}, + "agent2": {"tool_param": "value_for_agent2"}, + } + + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs=fi_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + # Each agent should receive only its own kwargs + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == {"tool_param": "value_for_agent1"} + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") == {"tool_param": "value_for_agent2"} + + +async def test_per_executor_kwargs_unmatched_agent_gets_none() -> None: + """An agent not targeted in per-executor kwargs should receive None for that kwarg.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + # Only agent1 is targeted + fi_kwargs = {"agent1": {"tool_param": "only_for_agent1"}} + + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs=fi_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == {"tool_param": "only_for_agent1"} + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") is None + + +async def test_global_client_invocation_kwargs_flow_to_all_agents() -> None: + """Global client_invocation_kwargs should be received by all agents.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + ci_kwargs = {"temperature": 0.5} + + async for event in workflow.run( + "test", + stream=True, + client_invocation_kwargs=ci_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("client_kwargs") == ci_kwargs + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("client_kwargs") == ci_kwargs + + +async def test_per_executor_client_invocation_kwargs_routes_correctly() -> None: + """Per-executor client_invocation_kwargs should only be received by the targeted agent.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + ci_kwargs = { + "agent1": {"temperature": 0.1}, + "agent2": {"temperature": 0.9}, + } + + async for event in workflow.run( + "test", + stream=True, + client_invocation_kwargs=ci_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("client_kwargs") == {"temperature": 0.1} + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("client_kwargs") == {"temperature": 0.9} + + +async def test_resolve_invocation_kwargs_logs_per_executor(caplog: "LogCaptureFixture") -> None: + """Workflow._resolve_invocation_kwargs logs info when per-executor format is detected.""" + import logging + + agent = _KwargsCapturingAgent(name="agent1") + workflow = SequentialBuilder(participants=[agent]).build() + + with caplog.at_level(logging.INFO): + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs={"agent1": {"key": "val"}}, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + per_executor_logs = [r for r in caplog.records if "per-executor" in r.message.lower()] + assert len(per_executor_logs) >= 1 + + +async def test_resolve_invocation_kwargs_logs_global(caplog: "LogCaptureFixture") -> None: + """Workflow._resolve_invocation_kwargs logs info when global format is detected.""" + import logging + + agent = _KwargsCapturingAgent(name="agent1") + workflow = SequentialBuilder(participants=[agent]).build() + + with caplog.at_level(logging.INFO): + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs={"tool_key": "tool_val"}, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + global_logs = [r for r in caplog.records if "global kwargs" in r.message.lower()] + assert len(global_logs) >= 1 + + +# endregion From 9012ca542014ddd3cf07a22c1fdd8e98a4088a1c Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 31 Mar 2026 14:45:28 -0700 Subject: [PATCH 4/9] Update samples --- .../state-management/workflow_kwargs.py | 148 ------------ .../workflow_kwargs_global.py | 170 ++++++++++++++ .../workflow_kwargs_per_agent.py | 222 ++++++++++++++++++ 3 files changed, 392 insertions(+), 148 deletions(-) delete mode 100644 python/samples/03-workflows/state-management/workflow_kwargs.py create mode 100644 python/samples/03-workflows/state-management/workflow_kwargs_global.py create mode 100644 python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py diff --git a/python/samples/03-workflows/state-management/workflow_kwargs.py b/python/samples/03-workflows/state-management/workflow_kwargs.py deleted file mode 100644 index 4a9ad4ad9f..0000000000 --- a/python/samples/03-workflows/state-management/workflow_kwargs.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -import json -import os -from typing import Annotated, Any, cast - -from agent_framework import Agent, Message, tool -from agent_framework.foundry import FoundryChatClient -from agent_framework.orchestrations import SequentialBuilder -from azure.identity import AzureCliCredential -from dotenv import load_dotenv -from pydantic import Field - -# Load environment variables from .env file -load_dotenv() - -""" -Sample: Workflow kwargs Flow to @tool Tools - -This sample demonstrates how to flow custom context (skill data, user tokens, etc.) -through any workflow pattern to @tool functions using the **kwargs pattern. - -Key Concepts: -- Pass custom context as kwargs when invoking workflow.run() -- kwargs are stored in State and passed to all agent invocations -- @tool functions receive kwargs via **kwargs parameter -- Works with Sequential, Concurrent, GroupChat, Handoff, and Magentic patterns - -Prerequisites: -- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint. -- Environment variables configured -""" - - -# Define tools that accept custom context via **kwargs -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; -# see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. -@tool(approval_mode="never_require") -def get_user_data( - query: Annotated[str, Field(description="What user data to retrieve")], - **kwargs: Any, -) -> str: - """Retrieve user-specific data based on the authenticated context.""" - user_token = kwargs.get("user_token", {}) - user_name = user_token.get("user_name", "anonymous") - access_level = user_token.get("access_level", "none") - - print(f"\n[get_user_data] Received kwargs keys: {list(kwargs.keys())}") - print(f"[get_user_data] User: {user_name}") - print(f"[get_user_data] Access level: {access_level}") - - return f"Retrieved data for user {user_name} with {access_level} access: {query}" - - -@tool(approval_mode="never_require") -def call_api( - endpoint_name: Annotated[str, Field(description="Name of the API endpoint to call")], - **kwargs: Any, -) -> str: - """Call an API using the configured endpoints from custom_data.""" - custom_data = kwargs.get("custom_data", {}) - api_config = custom_data.get("api_config", {}) - - base_url = api_config.get("base_url", "unknown") - endpoints = api_config.get("endpoints", {}) - - print(f"\n[call_api] Received kwargs keys: {list(kwargs.keys())}") - print(f"[call_api] Base URL: {base_url}") - print(f"[call_api] Available endpoints: {list(endpoints.keys())}") - - if endpoint_name in endpoints: - return f"Called {base_url}{endpoints[endpoint_name]} successfully" - return f"Endpoint '{endpoint_name}' not found in configuration" - - -async def main() -> None: - print("=" * 70) - print("Workflow kwargs Flow Demo (SequentialBuilder)") - print("=" * 70) - - # Create chat client - client = FoundryChatClient( - project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], - model=os.environ["FOUNDRY_MODEL"], - credential=AzureCliCredential(), - ) - - # Create agent with tools that use kwargs - agent = Agent( - client=client, - name="assistant", - instructions=( - "You are a helpful assistant. Use the available tools to help users. " - "When asked about user data, use get_user_data. " - "When asked to call an API, use call_api." - ), - tools=[get_user_data, call_api], - ) - - # Build a simple sequential workflow - workflow = SequentialBuilder(participants=[agent]).build() - - # Define custom context that will flow to tools via kwargs - custom_data = { - "api_config": { - "base_url": "https://api.example.com", - "endpoints": { - "users": "/v1/users", - "orders": "/v1/orders", - "products": "/v1/products", - }, - }, - } - - user_token = { - "user_name": "bob@contoso.com", - "access_level": "admin", - } - - print("\nCustom Data being passed:") - print(json.dumps(custom_data, indent=2)) - print(f"\nUser: {user_token['user_name']}") - print("\n" + "-" * 70) - print("Workflow Execution (watch for [tool_name] logs showing kwargs received):") - print("-" * 70) - - # Run workflow with kwargs - these will flow through to tools - async for event in workflow.run( - "Please get my user data and then call the users API endpoint.", - function_invocation_kwargs={"custom_data": custom_data, "user_token": user_token}, - stream=True, - ): - if event.type == "output": - output_data = cast(list[Message], event.data) - if isinstance(output_data, list): - for item in output_data: - if isinstance(item, Message) and item.text: - print(f"\n[{item.author_name}]: {item.text}") - - print("\n" + "=" * 70) - print("Sample Complete") - print("=" * 70) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/03-workflows/state-management/workflow_kwargs_global.py b/python/samples/03-workflows/state-management/workflow_kwargs_global.py new file mode 100644 index 0000000000..89f0c1672e --- /dev/null +++ b/python/samples/03-workflows/state-management/workflow_kwargs_global.py @@ -0,0 +1,170 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import json +import os +from typing import Annotated, Any, cast + +from agent_framework import Agent, Message, tool +from agent_framework.foundry import FoundryChatClient +from agent_framework.orchestrations import SequentialBuilder +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from pydantic import Field + +# Load environment variables from .env file +load_dotenv() + +""" +Sample: Global Workflow kwargs + +This sample demonstrates how to pass the same kwargs to every agent in a +workflow using global targeting. When keys in function_invocation_kwargs do NOT +match any executor ID (agent name), the framework treats them as global and +delivers them to all agents. + +Compare with workflow_kwargs_per_agent.py which targets kwargs to specific agents. + +Key Concepts: +- Global function_invocation_kwargs are delivered to every agent in the workflow +- Useful when all agents share the same credentials, config, or context +- @tool functions receive kwargs via the **kwargs parameter + +Prerequisites: +- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint. +- Environment variables configured +""" + + +# 1. Define a tool for the research agent — queries a company's internal +# database using credentials passed via global kwargs. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; +# see samples/02-agents/tools/function_tool_with_approval.py +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. +@tool(approval_mode="never_require") +def query_company_database( + query: Annotated[ + str, Field(description="The database query to run, e.g. 'Q3 revenue' or 'headcount by department'") + ], + **kwargs: Any, +) -> str: + """Query the company's internal database for business metrics and data.""" + db_config = kwargs.get("db_config", {}) + connection_string = db_config.get("connection_string", "") + database = db_config.get("database", "") + + if not connection_string or not database: + return f"ERROR: missing db_config — cannot run query '{query}'" + + print(f"\n [query_company_database] Connecting to {database} at {connection_string[:30]}...") + + # Simulated company data that the LLM would not know on its own + return ( + f"Query results from {database}:\n" + f"- Contoso Q3 2025 revenue: $47.2M (up 12% YoY)\n" + f"- Top product line: CloudSync Pro ($18.6M)\n" + f"- Engineering headcount: 342 (up from 298 in Q2)\n" + f"- Customer churn rate: 4.1% (down from 5.3% in Q2)\n" + f"- Net new enterprise customers: 28" + ) + + +# 2. Define a tool for the writer agent — retrieves the formatting style +# from user preferences passed via global kwargs. +@tool(approval_mode="never_require") +def get_formatting_instructions( + section_title: Annotated[str, Field(description="The title of the section or report to format")], + **kwargs: Any, +) -> str: + """Get the formatting instructions based on user preferences.""" + user_prefs = kwargs.get("user_preferences", {}) + output_format = user_prefs.get("format", "plain") + language = user_prefs.get("language", "en") + + print(f"\n [get_formatting_instructions] Format: {output_format}, Language: {language}") + + return ( + f"Formatting rules for '{section_title}':\n" + f"- Output format: {output_format}\n" + f"- Language/locale: {language}\n" + f"- Include a footer: 'Generated in {output_format} for locale {language}'" + ) + + +async def main() -> None: + print("=" * 70) + print("Global Workflow kwargs Demo") + print("=" * 70) + + # 3. Create a shared chat client. + client = FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AzureCliCredential(), + ) + + # 4. Create two agents with different tools and responsibilities. + researcher = Agent( + client=client, + name="researcher", + instructions=( + "You are a data analyst. Call query_company_database exactly once " + "with the user's request as the query. Return the raw results." + ), + tools=[query_company_database], + ) + + writer = Agent( + client=client, + name="writer", + instructions=( + "You are a report writer. Call get_formatting_instructions exactly once, " + "then rewrite the data you receive into a polished report following those rules." + ), + tools=[get_formatting_instructions], + ) + + # 5. Build a sequential workflow: researcher -> writer. + workflow = SequentialBuilder(participants=[researcher, writer]).build() + + # 6. Define global kwargs — every agent receives all of these. + # Because the keys ("db_config", "user_preferences") do NOT match any + # executor ID ("researcher", "writer"), the framework treats them as + # global and delivers the full dict to every agent. + global_fi_kwargs = { + "db_config": { + "connection_string": "Server=contoso-sql.database.windows.net;Database=metrics", + "database": "contoso_metrics_prod", + }, + "user_preferences": { + "format": "markdown", + "language": "en-US", + }, + } + + print("\nGlobal function_invocation_kwargs (sent to all agents):") + print(json.dumps(global_fi_kwargs, indent=2)) + print("\n" + "-" * 70) + print("Workflow Execution:") + print("-" * 70) + + # 7. Run the workflow — every agent receives the same global kwargs. + async for event in workflow.run( + "Pull Contoso's Q3 2025 performance data and write an executive summary.", + function_invocation_kwargs=global_fi_kwargs, + stream=True, + ): + if event.type == "output": + output_data = cast(list[Message], event.data) + if isinstance(output_data, list): + for item in output_data: + if isinstance(item, Message) and item.text: + print(f"\n[{item.author_name}]: {item.text}") + + print("\n" + "=" * 70) + print("Sample Complete") + print("=" * 70) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py b/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py new file mode 100644 index 0000000000..9c12ac90fc --- /dev/null +++ b/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import json +import os +from typing import Annotated, Any, cast + +from agent_framework import Agent, Message, tool +from agent_framework.foundry import FoundryChatClient +from agent_framework.orchestrations import SequentialBuilder +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from pydantic import Field + +# Load environment variables from .env file +load_dotenv() + +""" +Sample: Per-Agent Workflow kwargs + +This sample demonstrates how to pass different kwargs to different agents in a +workflow using per-agent targeting. When keys in function_invocation_kwargs (or +client_invocation_kwargs) match executor IDs (agent names by default), each agent +receives only its own slice of the kwargs. + +Key Concepts: +- Per-agent function_invocation_kwargs target specific agents by executor ID +- Agents only receive the kwargs assigned to them (not other agents' kwargs) +- Useful when different agents need different credentials, configs, or context + +Prerequisites: +- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint. +- Environment variables configured +""" + + +# 1. Define a tool for the research agent — queries a company's internal +# database using credentials passed via per-agent kwargs. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; +# see samples/02-agents/tools/function_tool_with_approval.py +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. +@tool(approval_mode="never_require") +def query_company_database( + query: Annotated[ + str, Field(description="The database query to run, e.g. 'Q3 revenue' or 'headcount by department'") + ], + **kwargs: Any, +) -> str: + """Query the company's internal database for business metrics and data.""" + db_config = kwargs.get("db_config", {}) + connection_string = db_config.get("connection_string", "") + database = db_config.get("database", "") + + if not connection_string or not database: + return f"ERROR: missing db_config — cannot run query '{query}'" + + print(f"\n [query_company_database] Connecting to {database} at {connection_string[:30]}...") + + # Simulated company data that the LLM would not know on its own + return ( + f"Query results from {database}:\n" + f"- Contoso Q3 2025 revenue: $47.2M (up 12% YoY)\n" + f"- Top product line: CloudSync Pro ($18.6M)\n" + f"- Engineering headcount: 342 (up from 298 in Q2)\n" + f"- Customer churn rate: 4.1% (down from 5.3% in Q2)\n" + f"- Net new enterprise customers: 28" + ) + + +# 2. Define a tool for the writer agent — retrieves the formatting style +# from user preferences passed via per-agent kwargs. +@tool(approval_mode="never_require") +def get_formatting_instructions( + section_title: Annotated[str, Field(description="The title of the section or report to format")], + **kwargs: Any, +) -> str: + """Get the formatting instructions based on user preferences.""" + user_prefs = kwargs.get("user_preferences", {}) + output_format = user_prefs.get("format", "plain") + language = user_prefs.get("language", "en") + + print(f"\n [get_formatting_instructions] Format: {output_format}, Language: {language}") + + return ( + f"Formatting rules for '{section_title}':\n" + f"- Output format: {output_format}\n" + f"- Language/locale: {language}\n" + f"- Include a footer: 'Generated in {output_format} for locale {language}'" + ) + + +async def main() -> None: + print("=" * 70) + print("Per-Agent Workflow kwargs Demo") + print("=" * 70) + + # 3. Create a shared chat client. + client = FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AzureCliCredential(), + ) + + # 4. Create two agents with different tools and responsibilities. + researcher = Agent( + client=client, + name="researcher", + instructions=( + "You are a data analyst. Call query_company_database exactly once " + "with the user's request as the query. Return the raw results." + ), + tools=[query_company_database], + ) + + writer = Agent( + client=client, + name="writer", + instructions=( + "You are a report writer. Call get_formatting_instructions exactly once, " + "then rewrite the data you receive into a polished report following those rules." + ), + tools=[get_formatting_instructions], + ) + + # 5. Build a sequential workflow: researcher -> writer. + workflow = SequentialBuilder(participants=[researcher, writer]).build() + + # 6. Define per-agent kwargs — each agent gets only its own config. + # The keys ("researcher", "writer") match the agent names, which are + # used as executor IDs by default. + per_agent_fi_kwargs = { + "researcher": { + "db_config": { + "connection_string": "Server=contoso-sql.database.windows.net;Database=metrics", + "database": "contoso_metrics_prod", + }, + }, + "writer": { + "user_preferences": { + "format": "markdown", + "language": "en-US", + }, + }, + } + + print("\nPer-agent function_invocation_kwargs:") + print(json.dumps(per_agent_fi_kwargs, indent=2)) + print("\n" + "-" * 70) + print("Workflow Execution:") + print("-" * 70) + + # 7. Run the workflow — each agent receives only its targeted kwargs. + async for event in workflow.run( + "Pull Contoso's Q3 2025 performance data and write an executive summary.", + function_invocation_kwargs=per_agent_fi_kwargs, + stream=True, + ): + if event.type == "output": + output_data = cast(list[Message], event.data) + if isinstance(output_data, list): + for item in output_data: + if isinstance(item, Message) and item.text: + print(f"\n[{item.author_name}]: {item.text}") + + print("\n" + "=" * 70) + print("Sample Complete") + print("=" * 70) + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: + +Per-agent function_invocation_kwargs: +{ + "researcher": { + "db_config": { + "connection_string": "Server=contoso-sql.database.windows.net;Database=metrics", + "database": "contoso_metrics_prod" + } + }, + "writer": { + "user_preferences": { + "format": "markdown", + "language": "en-US" + } + } +} + +---------------------------------------------------------------------- +Workflow Execution: +---------------------------------------------------------------------- + + [query_company_database] Connecting to contoso_metrics_prod at Server=contoso-sql.database.wi... + +[researcher]: Here is Contoso's Q3 2025 data: +- Revenue: $47.2M (up 12% YoY) +- Top product: CloudSync Pro ($18.6M) +- Engineering headcount: 342 +- Churn rate: 4.1% +- Net new enterprise customers: 28 + + [get_formatting_instructions] Format: markdown, Language: en-US + +[writer]: # Contoso Q3 2025 Executive Summary + +| Metric | Value | +|---|---| +| Revenue | $47.2M (+12% YoY) | +| Top Product | CloudSync Pro ($18.6M) | +| Engineering Headcount | 342 | +| Customer Churn | 4.1% | +| New Enterprise Customers | 28 | + +Generated in markdown for locale en-US + +====================================================================== +Sample Complete +====================================================================== +""" From 57a6efb02ad60c41ca126de44389764276e17781 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 31 Mar 2026 14:50:39 -0700 Subject: [PATCH 5/9] Fix formatting --- python/packages/core/agent_framework/_evaluation.py | 9 +++------ .../packages/core/tests/workflow/test_agent_executor.py | 4 +--- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/python/packages/core/agent_framework/_evaluation.py b/python/packages/core/agent_framework/_evaluation.py index 92a694cc36..80bc1ecffb 100644 --- a/python/packages/core/agent_framework/_evaluation.py +++ b/python/packages/core/agent_framework/_evaluation.py @@ -96,6 +96,7 @@ def split_before_memory(conversation): # Fallback: split at last user message return EvalItem._split_last_turn_static(conversation) + item.split_messages(split=split_before_memory) """ @@ -468,10 +469,7 @@ def raise_for_status(self, msg: str | None = None) -> None: """ if not self.all_passed: errored = (self.result_counts or {}).get("errored", 0) - detail = msg or ( - f"Eval run {self.run_id} {self.status}: " - f"{self.passed} passed, {self.failed} failed." - ) + detail = msg or (f"Eval run {self.run_id} {self.status}: {self.passed} passed, {self.failed} failed.") if errored: detail += f" {errored} errored." if self.report_url: @@ -1188,8 +1186,7 @@ def _coerce_result(value: Any, check_name: str) -> CheckResult: score = float(d["score"]) except (TypeError, ValueError) as exc: raise TypeError( - f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value:" - f" {d['score']!r}" + f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value: {d['score']!r}" ) from exc # Honour an explicit 'passed' override; otherwise threshold-based. passed = bool(d["passed"]) if "passed" in d else score >= float(d.get("threshold", 0.5)) diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index d68b3fd4b5..7f58472b24 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -22,6 +22,7 @@ ) from agent_framework._workflows._agent_executor import AgentExecutorResponse from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage +from agent_framework._workflows._const import GLOBAL_KWARGS_KEY if TYPE_CHECKING: from _pytest.logging import LogCaptureFixture @@ -651,9 +652,6 @@ async def test_checkpoint_restore_works_without_context_mode_in_state() -> None: # --------------------------------------------------------------------------- -from agent_framework._workflows._const import GLOBAL_KWARGS_KEY - - async def test_resolve_executor_kwargs_returns_global_kwargs() -> None: """_resolve_executor_kwargs with __global__ key returns the global kwargs.""" agent = _CountingAgent(id="a", name="A") From bd3eea71730da74c2006284a9d4d56c5a9be81b2 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 1 Apr 2026 11:37:23 -0700 Subject: [PATCH 6/9] Comments --- .../core/agent_framework/_workflows/_agent.py | 94 +++- .../_workflows/_agent_executor.py | 42 +- .../agent_framework/_workflows/_workflow.py | 58 +- .../_workflows/_workflow_executor.py | 15 +- .../tests/workflow/test_agent_executor.py | 111 +--- .../tests/workflow/test_workflow_kwargs.py | 516 +++--------------- .../workflow_kwargs_per_agent.py | 2 +- 7 files changed, 211 insertions(+), 627 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index bf615814b3..ff98befdab 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -6,7 +6,7 @@ import logging import sys import uuid -from collections.abc import AsyncIterable, Awaitable, Sequence +from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload @@ -152,7 +152,8 @@ def run( session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... @overload @@ -164,7 +165,8 @@ async def run( session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AgentResponse: ... def run( @@ -175,7 +177,8 @@ def run( session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> ResponseStream[AgentResponseUpdate, AgentResponse] | Awaitable[AgentResponse]: """Get a response from the workflow agent. @@ -192,8 +195,12 @@ def run( checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id, used to load and restore the checkpoint. When provided without checkpoint_id, enables checkpointing for this run. - **kwargs: Additional keyword arguments passed through to underlying workflow - and tool functions. + function_invocation_kwargs: Keyword arguments forwarded to tool invocations in + subagents. Either a mapping of agent name/executor id to kwargs, or a flat + mapping of kwargs for all tool invocations. + client_kwargs: Keyword arguments forwarded to chat client calls in + subagents. Either a mapping of agent name/executor id to kwargs, or a flat + mapping of kwargs for all chat client calls. Returns: When stream=True: An AsyncIterable[AgentResponseUpdate] for streaming updates. @@ -208,10 +215,26 @@ def run( response_id = str(uuid.uuid4()) if stream: return ResponseStream( - self._run_stream_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs), + self._run_stream_impl( + messages, + response_id, + session, + checkpoint_id, + checkpoint_storage, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ), finalizer=AgentResponse.from_updates, ) - return self._run_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs) + return self._run_impl( + messages, + response_id, + session, + checkpoint_id, + checkpoint_storage, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ) async def _run_impl( self, @@ -220,7 +243,8 @@ async def _run_impl( session: AgentSession | None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AgentResponse: """Internal implementation of non-streaming execution. @@ -230,8 +254,8 @@ async def _run_impl( session: The agent session for conversation context. checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. - **kwargs: Additional keyword arguments passed through to the underlying - workflow and tool functions. + function_invocation_kwargs: Optional kwargs for tool invocations. + client_kwargs: Optional kwargs for chat client calls. Returns: An AgentResponse representing the workflow execution results. @@ -264,7 +288,12 @@ async def _run_impl( output_events: list[WorkflowEvent[Any]] = [] async for event in self._run_core( - session_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs + session_messages, + checkpoint_id, + checkpoint_storage, + streaming=False, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): if event.type == "output" or event.type == "request_info": output_events.append(event) @@ -285,7 +314,8 @@ async def _run_stream_impl( session: AgentSession | None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AsyncIterable[AgentResponseUpdate]: """Internal implementation of streaming execution. @@ -295,8 +325,8 @@ async def _run_stream_impl( session: The agent session for conversation context. checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. - **kwargs: Additional keyword arguments passed through to the underlying - workflow and tool functions. + function_invocation_kwargs: Optional kwargs for tool invocations. + client_kwargs: Optional kwargs for chat client calls. Yields: AgentResponseUpdate objects representing the workflow execution progress. @@ -329,7 +359,12 @@ async def _run_stream_impl( session_messages: list[Message] = session_context.get_messages(include_input=True) all_updates: list[AgentResponseUpdate] = [] async for event in self._run_core( - session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs + session_messages, + checkpoint_id, + checkpoint_storage, + streaming=True, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): updates = self._convert_workflow_event_to_agent_response_updates(response_id, event) for update in updates: @@ -349,7 +384,8 @@ async def _run_core( checkpoint_id: str | None, checkpoint_storage: CheckpointStorage | None, streaming: bool, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Core implementation that yields workflow events for both streaming and non-streaming modes. @@ -358,12 +394,18 @@ async def _run_core( checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. streaming: Whether to use streaming workflow methods. - **kwargs: Additional keyword arguments passed through to the underlying - workflow and tool functions. + function_invocation_kwargs: Optional kwargs for tool invocations. + client_kwargs: Optional kwargs for chat client calls. Yields: WorkflowEvent objects from the workflow execution. """ + invocation_kwargs: dict[str, Any] = {} + if function_invocation_kwargs is not None: + invocation_kwargs["function_invocation_kwargs"] = function_invocation_kwargs + if client_kwargs is not None: + invocation_kwargs["client_kwargs"] = client_kwargs + # Determine the execution mode based on state. # The streaming flag controls the workflow's internal streaming mode, # which affects executor behavior (e.g. AgentExecutor emits different event @@ -371,10 +413,12 @@ async def _run_core( if bool(self.pending_requests): function_responses = self._process_pending_requests(input_messages) if streaming: - async for event in self.workflow.run(responses=function_responses, stream=True, **kwargs): + async for event in self.workflow.run( + responses=function_responses, stream=True, **invocation_kwargs + ): yield event else: - for event in await self.workflow.run(responses=function_responses, **kwargs): + for event in await self.workflow.run(responses=function_responses, **invocation_kwargs): yield event elif checkpoint_id is not None: @@ -383,14 +427,14 @@ async def _run_core( stream=True, checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, - **kwargs, + **invocation_kwargs, ): yield event else: for event in await self.workflow.run( checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, - **kwargs, + **invocation_kwargs, ): yield event @@ -400,14 +444,14 @@ async def _run_core( message=input_messages, stream=True, checkpoint_storage=checkpoint_storage, - **kwargs, + **invocation_kwargs, ): yield event else: for event in await self.workflow.run( message=input_messages, checkpoint_storage=checkpoint_storage, - **kwargs, + **invocation_kwargs, ): yield event diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 94aa5e8a38..72c3a91b74 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -335,7 +335,7 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR Returns: The complete AgentResponse, or None if waiting for user input. """ - function_invocation_kwargs, client_kwargs, backward_compatible_kwargs = self._prepare_agent_run_args( + function_invocation_kwargs, client_kwargs = self._prepare_agent_run_args( ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) ) @@ -346,7 +346,6 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR session=self._session, function_invocation_kwargs=function_invocation_kwargs, client_kwargs=client_kwargs, - **backward_compatible_kwargs, ) await ctx.yield_output(response) @@ -368,7 +367,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp Returns: The complete AgentResponse, or None if waiting for user input. """ - function_invocation_kwargs, client_kwargs, backward_compatible_kwargs = self._prepare_agent_run_args( + function_invocation_kwargs, client_kwargs = self._prepare_agent_run_args( ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) ) @@ -381,7 +380,6 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp session=self._session, function_invocation_kwargs=function_invocation_kwargs, client_kwargs=client_kwargs, - **backward_compatible_kwargs, ) async for update in stream: updates.append(update) @@ -427,50 +425,28 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp return response - # Parameters that are explicitly passed to agent.run() by AgentExecutor - # and must not appear in **run_kwargs to avoid TypeError from duplicate values. - _RESERVED_RUN_PARAMS: frozenset[str] = frozenset({"session", "stream", "messages"}) - def _prepare_agent_run_args( self, raw_run_kwargs: dict[str, Any], - ) -> tuple[dict[str, Any] | None, dict[str, Any] | None, dict[str, Any]]: - """Prepare function_invocation_kwargs, client_kwargs, and backward-compatible kwargs for agent.run(). + ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: + """Prepare function_invocation_kwargs and client_kwargs for agent.run(). - Extracts ``function_invocation_kwargs`` and ``client_invocation_kwargs`` from the + Extracts ``function_invocation_kwargs`` and ``client_kwargs`` from the workflow state dict, resolving per-executor entries using ``self.id``. The ``__global__`` sentinel key (set by ``Workflow._resolve_invocation_kwargs``) denotes global kwargs that apply to all executors. Per-executor dicts use executor IDs as keys; this executor extracts only its own entry. - Any remaining keys in the raw dict are treated as backward-compatible agent.run() - kwargs. Reserved parameters (session, stream, messages) that are explicitly managed - by AgentExecutor are stripped to prevent duplicate-keyword collisions. - Returns: - A 3-tuple of (function_invocation_kwargs, client_kwargs, backward_compatible_kwargs). + A 2-tuple of (function_invocation_kwargs, client_kwargs). """ - run_kwargs = dict(raw_run_kwargs) - - # Extract the already-resolved invocation kwargs dicts - # (set by Workflow._resolve_invocation_kwargs). - fi_resolved = run_kwargs.pop("function_invocation_kwargs", None) - ci_resolved = run_kwargs.pop("client_invocation_kwargs", None) + fi_resolved = raw_run_kwargs.get("function_invocation_kwargs") + ci_resolved = raw_run_kwargs.get("client_kwargs") function_invocation_kwargs = self._resolve_executor_kwargs(fi_resolved) client_kwargs = self._resolve_executor_kwargs(ci_resolved) - # Strip reserved params that AgentExecutor passes explicitly to agent.run(). - for key in AgentExecutor._RESERVED_RUN_PARAMS: - if key in run_kwargs: - logger.warning( - "Workflow kwarg '%s' is reserved by AgentExecutor and will be ignored. " - "Remove it from workflow.run() kwargs to silence this warning.", - key, - ) - run_kwargs.pop(key) - - return function_invocation_kwargs, client_kwargs, run_kwargs + return function_invocation_kwargs, client_kwargs def _resolve_executor_kwargs(self, resolved: dict[str, Any] | None) -> dict[str, Any] | None: """Extract this executor's kwargs from a resolved invocation kwargs dict. diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index e862648402..58050eece9 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -10,7 +10,6 @@ import logging import types import uuid -import warnings from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from typing import Any, Literal, overload @@ -300,8 +299,7 @@ async def _run_workflow_with_tracing( reset_context: bool = True, streaming: bool = False, function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, - client_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, - run_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Private method to run workflow with proper tracing. @@ -314,9 +312,8 @@ async def _run_workflow_with_tracing( streaming: Whether to enable streaming mode for agents function_invocation_kwargs: Optional kwargs to store in State for function invocations in subagents - client_invocation_kwargs: Optional kwargs to store in State for chat client + client_kwargs: Optional kwargs to store in State for chat client invocations in subagents - run_kwargs: Deprecated optional kwargs to store in State for agent invocations Yields: WorkflowEvent: The events generated during the workflow execution. @@ -355,21 +352,17 @@ async def _run_workflow_with_tracing( # Only overwrite when new kwargs are explicitly provided or state was # just cleared (fresh run). On continuation (reset_context=False) with # no new kwargs, preserve the kwargs from the original run. - if function_invocation_kwargs or client_invocation_kwargs: + if function_invocation_kwargs is not None or client_kwargs is not None: combined_kwargs: dict[str, Any] = {} - if function_invocation_kwargs: + if function_invocation_kwargs is not None: combined_kwargs["function_invocation_kwargs"] = self._resolve_invocation_kwargs( function_invocation_kwargs, "function_invocation_kwargs" ) - if client_invocation_kwargs: - combined_kwargs["client_invocation_kwargs"] = self._resolve_invocation_kwargs( - client_invocation_kwargs, "client_invocation_kwargs" + if client_kwargs is not None: + combined_kwargs["client_kwargs"] = self._resolve_invocation_kwargs( + client_kwargs, "client_kwargs" ) self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs) - elif run_kwargs is not None: - # Deprecated path for direct kwargs - still support but prefer the more explicit - # function_invocation_kwargs/client_invocation_kwargs - self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs) elif reset_context: self._state.set(WORKFLOW_RUN_KWARGS_KEY, {}) self._state.commit() # Commit immediately so kwargs are available @@ -481,7 +474,7 @@ def run( checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, - **kwargs: Any, + client_kwargs: Mapping[str, Any] | None = None, ) -> ResponseStream[WorkflowEvent, WorkflowRunResult]: ... @overload @@ -495,7 +488,7 @@ def run( checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, function_invocation_kwargs: Mapping[str, Any] | None = None, - **kwargs: Any, + client_kwargs: Mapping[str, Any] | None = None, ) -> Awaitable[WorkflowRunResult]: ... def run( @@ -508,8 +501,7 @@ def run( checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, - client_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, - **kwargs: Any, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> ResponseStream[WorkflowEvent, WorkflowRunResult] | Awaitable[WorkflowRunResult]: """Run the workflow, optionally streaming events. @@ -534,11 +526,9 @@ def run( function_invocation_kwargs: Keyword arguments forwarded to tool invocations in subagents. Either a mapping for agent name or agent executor id to kwargs, or a flat mapping of kwargs for all tool invocations. - client_invocation_kwargs: Keyword arguments forwarded to chat client calls in + client_kwargs: Keyword arguments forwarded to chat client calls in subagents. Either a mapping for agent name or agent executor id to kwargs, or a flat mapping of kwargs for all chat client calls. - **kwargs: Deprecated additional keyword arguments for the subagents. They are - forwarded to both tool invocations and the chat clients. Returns: When stream=True: A ResponseStream[WorkflowEvent, WorkflowRunResult] for @@ -548,19 +538,6 @@ def run( Raises: ValueError: If parameter combination is invalid. """ - if kwargs: - warnings.warn( - "Passing runtime keyword arguments directly to run() is deprecated; pass tool values via " - "function_invocation_kwargs and client_invocation_kwargs instead.", - DeprecationWarning, - stacklevel=2, - ) - if function_invocation_kwargs or client_invocation_kwargs: - raise ValueError( - "Cannot provide both deprecated kwargs and function_invocation_kwargs/client_invocation_kwargs. " - "Please consolidate to function_invocation_kwargs/client_invocation_kwargs." - ) - # Validate parameters and set running flag eagerly (before any async work) self._validate_run_params(message, responses, checkpoint_id) self._ensure_not_running() @@ -573,8 +550,7 @@ def run( checkpoint_storage=checkpoint_storage, streaming=stream, function_invocation_kwargs=function_invocation_kwargs, - client_invocation_kwargs=client_invocation_kwargs, - **kwargs, + client_kwargs=client_kwargs, ), finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events), cleanup_hooks=[ @@ -595,8 +571,7 @@ async def _run_core( checkpoint_storage: CheckpointStorage | None = None, streaming: bool = False, function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, - client_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, - **kwargs: Any, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Single core execution path for both streaming and non-streaming modes. @@ -616,12 +591,7 @@ async def _run_core( reset_context=reset_context, streaming=streaming, function_invocation_kwargs=function_invocation_kwargs, - client_invocation_kwargs=client_invocation_kwargs, - # Empty **kwargs (no caller-provided kwargs) is collapsed to None so that - # continuation calls without explicit kwargs preserve the original run's kwargs. - # A non-empty kwargs dict (even one with empty values like {"key": {}}) - # is passed through and will overwrite stored kwargs. - run_kwargs=kwargs if kwargs else None, + client_kwargs=client_kwargs, ): if event.type == "output" and not self._should_yield_output_event(event): continue diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index e9e4196bfd..b4dfb6a7f8 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -12,7 +12,7 @@ from ._workflow import Workflow from ._checkpoint_encoding import decode_checkpoint_value -from ._const import WORKFLOW_RUN_KWARGS_KEY +from ._const import GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY from ._events import ( WorkflowEvent, WorkflowRunState, @@ -387,8 +387,19 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) # Get kwargs from parent workflow's State to propagate to subworkflow parent_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) + # Extract invocation kwargs recognised by Workflow.run() + # The state stores resolved format (with __global__ wrapper for global kwargs). + # Unwrap __global__ before passing to the subworkflow so it gets re-resolved + # against the subworkflow's own executor IDs. + invocation_kwargs: dict[str, Any] = {} + for key in ("function_invocation_kwargs", "client_kwargs"): + resolved = parent_kwargs.get(key) + if isinstance(resolved, dict): + # Unwrap global sentinel; pass per-executor dicts as-is + invocation_kwargs[key] = resolved.get(GLOBAL_KWARGS_KEY, resolved) + # Run the sub-workflow and collect all events, passing parent kwargs - result = await self.workflow.run(input_data, **parent_kwargs) + result = await self.workflow.run(input_data, **invocation_kwargs) logger.debug( f"WorkflowExecutor {self.id} sub-workflow {self.workflow.id} " diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 7f58472b24..4e32265a44 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -1,8 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. -import logging from collections.abc import AsyncIterable, Awaitable -from typing import TYPE_CHECKING, Any, Literal, overload +from typing import Any, Literal, overload import pytest @@ -24,9 +23,6 @@ from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework._workflows._const import GLOBAL_KWARGS_KEY -if TYPE_CHECKING: - from _pytest.logging import LogCaptureFixture - class _CountingAgent(BaseAgent): """Agent that echoes messages with a counter to verify session state persistence.""" @@ -310,93 +306,28 @@ async def test_agent_executor_save_and_restore_state_directly() -> None: assert restored_session.session_id == session.session_id -async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None: - """Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295).""" - agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent") - executor = AgentExecutor(agent, id="session_kwarg_exec") - workflow = WorkflowBuilder(start_executor=executor).build() - - # This previously raised: TypeError: run() got multiple values for keyword argument 'session' - result = await workflow.run("hello", session="user-supplied-value") - assert result is not None - assert agent.call_count == 1 - - -async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -> None: - """Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError.""" - agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent") - executor = AgentExecutor(agent, id="stream_kwarg_exec") - workflow = WorkflowBuilder(start_executor=executor).build() - - # stream=True at workflow level triggers streaming mode (returns async iterable) - events: list[WorkflowEvent] = [] - async for event in workflow.run("hello", stream=True): - events.append(event) - assert len(events) > 0 - assert agent.call_count == 1 - - -@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"]) -async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None: - """_prepare_agent_run_args must remove reserved kwargs and log a warning.""" +async def test_prepare_agent_run_args_extracts_invocation_kwargs() -> None: + """_prepare_agent_run_args extracts function_invocation_kwargs and client_kwargs.""" agent = _CountingAgent(id="test_agent", name="TestAgent") executor = AgentExecutor(agent, id="test_exec") raw: dict[str, Any] = { - reserved_kwarg: "should-be-stripped", - "custom_key": "keep-me", + "function_invocation_kwargs": {"__global__": {"key": "fi_val"}}, + "client_kwargs": {"__global__": {"key": "ci_val"}}, } + fi_kwargs, ci_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs == {"key": "fi_val"} + assert ci_kwargs == {"key": "ci_val"} - with caplog.at_level(logging.WARNING): - _, _, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] - - assert reserved_kwarg not in backward_kwargs - assert "custom_key" in backward_kwargs - assert backward_kwargs["custom_key"] == "keep-me" - assert any(reserved_kwarg in record.message for record in caplog.records) - - -async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None: - """Non-reserved workflow kwargs should pass through unchanged.""" - agent = _CountingAgent(id="test_agent", name="TestAgent") - executor = AgentExecutor(agent, id="test_exec") - - raw: dict[str, Any] = {"custom_param": "value", "another": 42} - _, _, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] - assert backward_kwargs["custom_param"] == "value" - assert backward_kwargs["another"] == 42 - -async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once( - caplog: "LogCaptureFixture", -) -> None: - """All reserved kwargs should be stripped when supplied together, each emitting a warning.""" +async def test_prepare_agent_run_args_returns_none_when_no_kwargs() -> None: + """_prepare_agent_run_args returns None for both when raw dict has no invocation kwargs.""" agent = _CountingAgent(id="test_agent", name="TestAgent") executor = AgentExecutor(agent, id="test_exec") - raw: dict[str, Any] = {"session": "x", "stream": True, "messages": [], "custom": 1} - - with caplog.at_level(logging.WARNING): - _, _, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] - - assert "session" not in backward_kwargs - assert "stream" not in backward_kwargs - assert "messages" not in backward_kwargs - assert backward_kwargs["custom"] == 1 - - warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()} - assert warned_keys == {"session", "stream", "messages"} - - -async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None: - """Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError.""" - agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent") - executor = AgentExecutor(agent, id="messages_kwarg_exec") - workflow = WorkflowBuilder(start_executor=executor).build() - - result = await workflow.run("hello", messages=["stale"]) - assert result is not None - assert agent.call_count == 1 + fi_kwargs, ci_kwargs = executor._prepare_agent_run_args({}) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs is None + assert ci_kwargs is None class _NonCopyableRaw: @@ -720,24 +651,22 @@ async def test_prepare_agent_run_args_extracts_function_invocation_kwargs() -> N raw: dict[str, Any] = { "function_invocation_kwargs": {GLOBAL_KWARGS_KEY: {"tool_key": "tool_val"}}, } - fi_kwargs, client_kwargs, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + fi_kwargs, client_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] assert fi_kwargs == {"tool_key": "tool_val"} assert client_kwargs is None - assert "function_invocation_kwargs" not in backward_kwargs -async def test_prepare_agent_run_args_extracts_client_invocation_kwargs() -> None: - """_prepare_agent_run_args extracts client_invocation_kwargs from the state dict.""" +async def test_prepare_agent_run_args_extracts_client_kwargs() -> None: + """_prepare_agent_run_args extracts client_kwargs from the state dict.""" agent = _CountingAgent(id="a", name="A") executor = AgentExecutor(agent, id="exec_a") raw: dict[str, Any] = { - "client_invocation_kwargs": {GLOBAL_KWARGS_KEY: {"model": "gpt-4"}}, + "client_kwargs": {GLOBAL_KWARGS_KEY: {"model": "gpt-4"}}, } - fi_kwargs, client_kwargs, backward_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + fi_kwargs, client_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] assert fi_kwargs is None assert client_kwargs == {"model": "gpt-4"} - assert "client_invocation_kwargs" not in backward_kwargs async def test_prepare_agent_run_args_per_executor_resolution() -> None: @@ -751,7 +680,7 @@ async def test_prepare_agent_run_args_per_executor_resolution() -> None: "exec_b": {"other_tool_key": "other_val"}, }, } - fi_kwargs, _, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + fi_kwargs, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] assert fi_kwargs == {"my_tool_key": "my_val"} @@ -766,5 +695,5 @@ async def test_prepare_agent_run_args_per_executor_no_match() -> None: "exec_b": {"other_tool_key": "other_val"}, }, } - fi_kwargs, _, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + fi_kwargs, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] assert fi_kwargs is None diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index d14be92bca..83927893cc 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import AsyncIterable, Awaitable -from typing import TYPE_CHECKING, Annotated, Any, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, overload import pytest @@ -15,7 +15,6 @@ Message, ResponseStream, WorkflowRunState, - tool, ) from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY from agent_framework.orchestrations import ( @@ -29,22 +28,8 @@ if TYPE_CHECKING: from _pytest.logging import LogCaptureFixture -# Track kwargs received by tools during test execution -_received_kwargs: list[dict[str, Any]] = [] - - -@tool(approval_mode="never_require") -def tool_with_kwargs( - action: Annotated[str, "The action to perform"], - **kwargs: Any, -) -> str: - """A test tool that captures kwargs for verification.""" - _received_kwargs.append(dict(kwargs)) - custom_data = kwargs.get("custom_data", {}) - user_token = kwargs.get("user_token", {}) - return f"Executed {action} with custom_data={custom_data}, user={user_token.get('user_name', 'unknown')}" - +# Track kwargs received by tools during test execution class _KwargsCapturingAgent(BaseAgent): """Test agent that captures kwargs passed to run.""" @@ -95,76 +80,20 @@ async def _run() -> AgentResponse: return _run() -class _OptionsAwareAgent(BaseAgent): - """Test agent that captures explicit `options` and kwargs passed to run().""" - - captured_options: list[dict[str, Any] | None] - captured_kwargs: list[dict[str, Any]] - - def __init__(self, name: str = "options_agent") -> None: - super().__init__(name=name, description="Test agent for options capture") - self.captured_options = [] - self.captured_kwargs = [] - - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[False] = ..., - session: AgentSession | None = ..., - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]]: ... - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[True], - session: AgentSession | None = ..., - **kwargs: Any, - ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - - def run( - self, - messages: AgentRunInputs | None = None, - *, - stream: bool = False, - session: AgentSession | None = None, - options: dict[str, Any] | None = None, - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: - self.captured_options.append(dict(options) if options is not None else None) - self.captured_kwargs.append(dict(kwargs)) - if stream: - - async def _stream() -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} response")]) - - return ResponseStream(_stream(), finalizer=AgentResponse.from_updates) - - async def _run() -> AgentResponse: - return AgentResponse(messages=[Message("assistant", [f"{self.name} response"])]) - - return _run() - - # region Sequential Builder Tests async def test_sequential_kwargs_flow_to_agent() -> None: - """Test that kwargs passed to SequentialBuilder workflow flow through to agent.""" + """Test that function_invocation_kwargs passed to SequentialBuilder workflow flow through to agent.""" agent = _KwargsCapturingAgent(name="seq_agent") workflow = SequentialBuilder(participants=[agent]).build() - custom_data = {"endpoint": "https://api.example.com", "version": "v1"} - user_token = {"user_name": "alice", "access_level": "admin"} + fi_kwargs = {"endpoint": "https://api.example.com", "version": "v1"} async for event in workflow.run( "test message", stream=True, - custom_data=custom_data, - user_token=user_token, + function_invocation_kwargs=fi_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -172,66 +101,53 @@ async def test_sequential_kwargs_flow_to_agent() -> None: # Verify agent received kwargs assert len(agent.captured_kwargs) >= 1, "Agent should have been invoked at least once" received = agent.captured_kwargs[0] - assert "custom_data" in received, "Agent should receive custom_data kwarg" - assert "user_token" in received, "Agent should receive user_token kwarg" - assert received["custom_data"] == custom_data - assert received["user_token"] == user_token + assert received.get("function_invocation_kwargs") == fi_kwargs async def test_sequential_kwargs_flow_to_multiple_agents() -> None: - """Test that kwargs flow to all agents in a sequential workflow.""" + """Test that function_invocation_kwargs flow to all agents in a sequential workflow.""" agent1 = _KwargsCapturingAgent(name="agent1") agent2 = _KwargsCapturingAgent(name="agent2") workflow = SequentialBuilder(participants=[agent1, agent2]).build() - custom_data = {"key": "value"} + fi_kwargs = {"key": "value"} - async for event in workflow.run("test", custom_data=custom_data, stream=True): + async for event in workflow.run("test", function_invocation_kwargs=fi_kwargs, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break # Both agents should have received kwargs assert len(agent1.captured_kwargs) >= 1, "First agent should be invoked" assert len(agent2.captured_kwargs) >= 1, "Second agent should be invoked" - assert agent1.captured_kwargs[0].get("custom_data") == custom_data - assert agent2.captured_kwargs[0].get("custom_data") == custom_data + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs async def test_sequential_run_kwargs_flow() -> None: - """Test that kwargs flow through workflow.run() (non-streaming).""" + """Test that function_invocation_kwargs flow through workflow.run() (non-streaming).""" agent = _KwargsCapturingAgent(name="run_agent") workflow = SequentialBuilder(participants=[agent]).build() - _ = await workflow.run("test message", custom_data={"test": True}) + _ = await workflow.run("test message", function_invocation_kwargs={"test": True}) assert len(agent.captured_kwargs) >= 1 - assert agent.captured_kwargs[0].get("custom_data") == {"test": True} + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == {"test": True} -async def test_sequential_run_additional_function_arguments_flattened() -> None: - """Test workflow.run(additional_function_arguments=...) passes through as a backward-compat kwarg.""" - agent = _OptionsAwareAgent(name="options_agent") +async def test_sequential_run_non_streaming_kwargs_flow() -> None: + """Test workflow.run(function_invocation_kwargs=...) non-streaming path.""" + agent = _KwargsCapturingAgent(name="options_agent") workflow = SequentialBuilder(participants=[agent]).build() - custom_data = {"session_id": "abc123"} - user_token = {"user_name": "alice"} + fi_kwargs = {"session_id": "abc123"} - async for event in workflow.run( + _ = await workflow.run( "test message", - stream=True, - additional_function_arguments={"custom_data": custom_data, "user_token": user_token}, - ): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break + function_invocation_kwargs=fi_kwargs, + ) - # additional_function_arguments is passed through as a backward-compat kwarg, - # not merged into options. assert len(agent.captured_kwargs) >= 1 - captured_kwargs = agent.captured_kwargs[0] - assert captured_kwargs.get("additional_function_arguments") == { - "custom_data": custom_data, - "user_token": user_token, - } + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs # endregion @@ -241,19 +157,17 @@ async def test_sequential_run_additional_function_arguments_flattened() -> None: async def test_concurrent_kwargs_flow_to_agents() -> None: - """Test that kwargs flow to all agents in a concurrent workflow.""" + """Test that function_invocation_kwargs flow to all agents in a concurrent workflow.""" agent1 = _KwargsCapturingAgent(name="concurrent1") agent2 = _KwargsCapturingAgent(name="concurrent2") workflow = ConcurrentBuilder(participants=[agent1, agent2]).build() - custom_data = {"batch_id": "123"} - user_token = {"user_name": "bob"} + fi_kwargs = {"batch_id": "123"} async for event in workflow.run( "concurrent test", stream=True, - custom_data=custom_data, - user_token=user_token, + function_invocation_kwargs=fi_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -264,8 +178,7 @@ async def test_concurrent_kwargs_flow_to_agents() -> None: for agent in [agent1, agent2]: received = agent.captured_kwargs[0] - assert received.get("custom_data") == custom_data - assert received.get("user_token") == user_token + assert received.get("function_invocation_kwargs") == fi_kwargs # endregion @@ -275,7 +188,7 @@ async def test_concurrent_kwargs_flow_to_agents() -> None: async def test_groupchat_kwargs_flow_to_agents() -> None: - """Test that kwargs flow to agents in a group chat workflow.""" + """Test that function_invocation_kwargs flow to agents in a group chat workflow.""" agent1 = _KwargsCapturingAgent(name="chat1") agent2 = _KwargsCapturingAgent(name="chat2") @@ -297,9 +210,9 @@ def simple_selector(state: GroupChatState) -> str: selection_func=simple_selector, ).build() - custom_data = {"session_id": "group123"} + fi_kwargs = {"session_id": "group123"} - async for event in workflow.run("group chat test", custom_data=custom_data, stream=True): + async for event in workflow.run("group chat test", function_invocation_kwargs=fi_kwargs, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -308,7 +221,7 @@ def simple_selector(state: GroupChatState) -> str: assert len(all_kwargs) >= 1, "At least one agent should be invoked in group chat" for received in all_kwargs: - assert received.get("custom_data") == custom_data + assert received.get("function_invocation_kwargs") == fi_kwargs # endregion @@ -318,7 +231,7 @@ def simple_selector(state: GroupChatState) -> str: async def test_kwargs_stored_in_state() -> None: - """Test that kwargs are stored in State with the correct key.""" + """Test that function_invocation_kwargs are stored in State with the correct key.""" from agent_framework import Executor, WorkflowContext, handler stored_kwargs: dict[str, Any] | None = None @@ -333,13 +246,12 @@ async def inspect(self, msgs: list[Message], ctx: WorkflowContext[list[Message]] inspector = _StateInspector(id="inspector") workflow = SequentialBuilder(participants=[inspector]).build() - async for event in workflow.run("test", my_kwarg="my_value", another=123, stream=True): + async for event in workflow.run("test", function_invocation_kwargs={"my_kwarg": "my_value"}, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break assert stored_kwargs is not None, "kwargs should be stored in State" - assert stored_kwargs.get("my_kwarg") == "my_value" - assert stored_kwargs.get("another") == 123 + assert "function_invocation_kwargs" in stored_kwargs async def test_empty_kwargs_stored_as_empty_dict() -> None: @@ -373,24 +285,8 @@ async def check(self, msgs: list[Message], ctx: WorkflowContext[list[Message]]) # region Edge Cases -async def test_kwargs_with_none_values() -> None: - """Test that kwargs with None values are passed through correctly.""" - agent = _KwargsCapturingAgent(name="none_test") - workflow = SequentialBuilder(participants=[agent]).build() - - async for event in workflow.run("test", optional_param=None, other_param="value", stream=True): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - assert len(agent.captured_kwargs) >= 1 - received = agent.captured_kwargs[0] - assert "optional_param" in received - assert received["optional_param"] is None - assert received["other_param"] == "value" - - async def test_kwargs_with_complex_nested_data() -> None: - """Test that complex nested data structures flow through correctly.""" + """Test that complex nested data structures flow through correctly via function_invocation_kwargs.""" agent = _KwargsCapturingAgent(name="nested_test") workflow = SequentialBuilder(participants=[agent]).build() @@ -405,17 +301,17 @@ async def test_kwargs_with_complex_nested_data() -> None: "tuple_like": [1, 2, 3], } - async for event in workflow.run("test", complex_data=complex_data, stream=True): + async for event in workflow.run("test", function_invocation_kwargs=complex_data, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break assert len(agent.captured_kwargs) >= 1 received = agent.captured_kwargs[0] - assert received.get("complex_data") == complex_data + assert received.get("function_invocation_kwargs") == complex_data async def test_kwargs_preserved_on_response_continuation() -> None: - """Test that run kwargs are preserved when continuing a paused workflow with run(responses=...). + """Test that function_invocation_kwargs are preserved when continuing a paused workflow with run(responses=...). Regression test for #4293: kwargs were overwritten to {} on continuation calls. """ @@ -479,8 +375,9 @@ async def _done() -> AgentResponse: agent = _ApprovalCapturingAgent() workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build() - # Initial run with kwargs — workflow should pause for approval - result = await workflow.run("go", custom_data={"token": "abc"}) + # Initial run with function_invocation_kwargs — workflow should pause for approval + fi_kwargs = {"token": "abc"} + result = await workflow.run("go", function_invocation_kwargs=fi_kwargs) request_events = result.get_request_info_events() assert len(request_events) == 1 @@ -488,174 +385,14 @@ async def _done() -> AgentResponse: approval = request_events[0] await workflow.run(responses={approval.request_id: approval.data.to_function_approval_response(True)}) - # Both calls should have received the original kwargs + # Both calls should have received the original function_invocation_kwargs assert len(agent.captured_kwargs) == 2 - assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"} - assert agent.captured_kwargs[1].get("custom_data") == {"token": "abc"}, ( + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert agent.captured_kwargs[1].get("function_invocation_kwargs") == fi_kwargs, ( f"kwargs should be preserved on continuation, got: {agent.captured_kwargs[1]}" ) -async def test_kwargs_overridden_on_response_continuation() -> None: - """Test that explicitly provided kwargs override prior kwargs on continuation.""" - - class _ApprovalCapturingAgent(BaseAgent): - captured_kwargs: list[dict[str, Any]] - _asked: bool - - def __init__(self) -> None: - super().__init__(name="approval_agent", description="Test agent") - self.captured_kwargs = [] - self._asked = False - - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[False] = ..., - session: AgentSession | None = ..., - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]]: ... - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[True], - session: AgentSession | None = ..., - **kwargs: Any, - ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - - def run( - self, - messages: AgentRunInputs | None = None, - *, - stream: bool = False, - session: AgentSession | None = None, - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: - self.captured_kwargs.append(dict(kwargs)) - if not self._asked: - self._asked = True - - async def _pause() -> AgentResponse: - call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}") - req = Content.from_function_approval_request(id="r1", function_call=call) - return AgentResponse(messages=[Message("assistant", [req])]) - - return _pause() - - async def _done() -> AgentResponse: - return AgentResponse(messages=[Message("assistant", ["done"])]) - - return _done() - - from agent_framework import WorkflowBuilder - - agent = _ApprovalCapturingAgent() - workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build() - - result = await workflow.run("go", custom_data={"token": "abc"}) - request_events = result.get_request_info_events() - approval = request_events[0] - - # Continue with responses AND new kwargs — should override - await workflow.run( - responses={approval.request_id: approval.data.to_function_approval_response(True)}, - custom_data={"token": "xyz"}, - ) - - assert len(agent.captured_kwargs) == 2 - assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"} - assert agent.captured_kwargs[1].get("custom_data") == {"token": "xyz"} - - -async def test_kwargs_empty_value_passed_on_continuation() -> None: - """Test that explicitly passing a kwarg with an empty value on continuation overrides prior kwargs. - - This exercises the boundary where the caller provides kwargs (e.g., custom_data={}) - that differ from the original run. Because the kwargs dict is non-empty (it has a key), - it passes the `kwargs if kwargs else None` gate and the `is not None` check, so it - overwrites the previously stored kwargs. - """ - - class _ApprovalCapturingAgent(BaseAgent): - captured_kwargs: list[dict[str, Any]] - _asked: bool - - def __init__(self) -> None: - super().__init__(name="approval_agent", description="Test agent") - self.captured_kwargs = [] - self._asked = False - - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[False] = ..., - session: AgentSession | None = ..., - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]]: ... - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[True], - session: AgentSession | None = ..., - **kwargs: Any, - ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - - def run( - self, - messages: AgentRunInputs | None = None, - *, - stream: bool = False, - session: AgentSession | None = None, - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: - self.captured_kwargs.append(dict(kwargs)) - if not self._asked: - self._asked = True - - async def _pause() -> AgentResponse: - call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}") - req = Content.from_function_approval_request(id="r1", function_call=call) - return AgentResponse(messages=[Message("assistant", [req])]) - - return _pause() - - async def _done() -> AgentResponse: - return AgentResponse(messages=[Message("assistant", ["done"])]) - - return _done() - - from agent_framework import WorkflowBuilder - - agent = _ApprovalCapturingAgent() - workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build() - - # Initial run with non-empty kwargs - result = await workflow.run("go", custom_data={"token": "abc"}) - request_events = result.get_request_info_events() - assert len(request_events) == 1 - - # Continue with custom_data={} — explicitly clearing the value. - # kwargs={"custom_data": {}} is truthy (has a key), so run_kwargs is set. - approval = request_events[0] - await workflow.run( - responses={approval.request_id: approval.data.to_function_approval_response(True)}, - custom_data={}, - ) - - assert len(agent.captured_kwargs) == 2 - assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"} - # The continuation explicitly set custom_data={}, overriding the original - assert agent.captured_kwargs[1].get("custom_data") == {} - - async def test_kwargs_reset_context_stores_empty_dict() -> None: """Test that reset_context=True with no kwargs stores an empty dict. @@ -672,33 +409,9 @@ async def test_kwargs_reset_context_stores_empty_dict() -> None: break assert len(agent.captured_kwargs) >= 1 - # The only kwarg should be the framework-injected 'options' (no user-provided kwargs) received = agent.captured_kwargs[0] - assert "custom_data" not in received - assert received.get("options") is None - - -async def test_kwargs_preserved_across_workflow_reruns() -> None: - """Test that kwargs are correctly isolated between workflow runs.""" - agent = _KwargsCapturingAgent(name="rerun_test") - - # Build separate workflows for each run to avoid "already running" error - workflow1 = SequentialBuilder(participants=[agent]).build() - workflow2 = SequentialBuilder(participants=[agent]).build() - - # First run - async for event in workflow1.run("run1", run_id="first", stream=True): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - # Second run with different kwargs (using fresh workflow) - async for event in workflow2.run("run2", run_id="second", stream=True): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - assert len(agent.captured_kwargs) >= 2 - assert agent.captured_kwargs[0].get("run_id") == "first" - assert agent.captured_kwargs[1].get("run_id") == "second" + assert received.get("function_invocation_kwargs") is None + assert received.get("client_kwargs") is None # endregion @@ -709,7 +422,7 @@ async def test_kwargs_preserved_across_workflow_reruns() -> None: @pytest.mark.xfail(reason="Handoff workflow does not yet propagate kwargs to agents") async def test_handoff_kwargs_flow_to_agents() -> None: - """Test that kwargs flow to agents in a handoff workflow.""" + """Test that function_invocation_kwargs flow to agents in a handoff workflow.""" agent1 = _KwargsCapturingAgent(name="coordinator") agent2 = _KwargsCapturingAgent(name="specialist") @@ -721,15 +434,15 @@ async def test_handoff_kwargs_flow_to_agents() -> None: .build() ) - custom_data = {"session_id": "handoff123"} + fi_kwargs = {"session_id": "handoff123"} - async for event in workflow.run("handoff test", custom_data=custom_data, stream=True): + async for event in workflow.run("handoff test", function_invocation_kwargs=fi_kwargs, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break # Coordinator agent should have received kwargs assert len(agent1.captured_kwargs) >= 1, "Coordinator should be invoked in handoff" - assert agent1.captured_kwargs[0].get("custom_data") == custom_data + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs # endregion @@ -781,7 +494,7 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> Messa custom_data = {"session_id": "magentic123"} - async for event in workflow.run("magentic test", custom_data=custom_data, stream=True): + async for event in workflow.run("magentic test", function_invocation_kwargs=custom_data, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -832,7 +545,7 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> Messa # Use MagenticWorkflow.run() which goes through the kwargs attachment path custom_data = {"magentic_key": "magentic_value"} - async for event in magentic_workflow.run("test task", custom_data=custom_data, stream=True): + async for event in magentic_workflow.run("test task", function_invocation_kwargs=custom_data, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -847,86 +560,61 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> Messa async def test_workflow_as_agent_run_propagates_kwargs_to_underlying_agent() -> None: - """Test that kwargs passed to workflow_agent.run() flow through to the underlying agents.""" + """Test that function_invocation_kwargs passed to workflow_agent.run() flow through to the underlying agents.""" agent = _KwargsCapturingAgent(name="inner_agent") workflow = SequentialBuilder(participants=[agent]).build() workflow_agent = workflow.as_agent(name="TestWorkflowAgent") - custom_data = {"endpoint": "https://api.example.com", "version": "v1"} - user_token = {"user_name": "alice", "access_level": "admin"} + fi_kwargs = {"endpoint": "https://api.example.com", "version": "v1"} _ = await workflow_agent.run( "test message", - custom_data=custom_data, - user_token=user_token, + function_invocation_kwargs=fi_kwargs, ) # Verify inner agent received kwargs assert len(agent.captured_kwargs) >= 1, "Inner agent should have been invoked at least once" received = agent.captured_kwargs[0] - assert "custom_data" in received, "Inner agent should receive custom_data kwarg" - assert "user_token" in received, "Inner agent should receive user_token kwarg" - assert received["custom_data"] == custom_data - assert received["user_token"] == user_token + assert received.get("function_invocation_kwargs") == fi_kwargs async def test_workflow_as_agent_run_stream_propagates_kwargs_to_underlying_agent() -> None: - """Test that kwargs passed to workflow_agent.run() flow through to the underlying agents.""" + """Test that function_invocation_kwargs passed to workflow_agent.run(stream=True) flow through.""" agent = _KwargsCapturingAgent(name="inner_agent") workflow = SequentialBuilder(participants=[agent]).build() workflow_agent = workflow.as_agent(name="TestWorkflowAgent") - custom_data = {"session_id": "xyz123"} - api_token = "secret-token" + fi_kwargs = {"session_id": "xyz123"} async for _ in workflow_agent.run( "test message", stream=True, - custom_data=custom_data, - api_token=api_token, + function_invocation_kwargs=fi_kwargs, ): pass # Verify inner agent received kwargs assert len(agent.captured_kwargs) >= 1, "Inner agent should have been invoked at least once" received = agent.captured_kwargs[0] - assert "custom_data" in received, "Inner agent should receive custom_data kwarg" - assert "api_token" in received, "Inner agent should receive api_token kwarg" - assert received["custom_data"] == custom_data - assert received["api_token"] == api_token + assert received.get("function_invocation_kwargs") == fi_kwargs async def test_workflow_as_agent_propagates_kwargs_to_multiple_agents() -> None: - """Test that kwargs flow to all agents when using workflow.as_agent().""" + """Test that function_invocation_kwargs flow to all agents when using workflow.as_agent().""" agent1 = _KwargsCapturingAgent(name="agent1") agent2 = _KwargsCapturingAgent(name="agent2") workflow = SequentialBuilder(participants=[agent1, agent2]).build() workflow_agent = workflow.as_agent(name="MultiAgentWorkflow") - custom_data = {"batch_id": "batch-001"} + fi_kwargs = {"batch_id": "batch-001"} - _ = await workflow_agent.run("test message", custom_data=custom_data) + _ = await workflow_agent.run("test message", function_invocation_kwargs=fi_kwargs) # Both agents should have received kwargs assert len(agent1.captured_kwargs) >= 1, "First agent should be invoked" assert len(agent2.captured_kwargs) >= 1, "Second agent should be invoked" - assert agent1.captured_kwargs[0].get("custom_data") == custom_data - assert agent2.captured_kwargs[0].get("custom_data") == custom_data - - -async def test_workflow_as_agent_kwargs_with_none_values() -> None: - """Test that kwargs with None values are passed through correctly via as_agent().""" - agent = _KwargsCapturingAgent(name="none_test_agent") - workflow = SequentialBuilder(participants=[agent]).build() - workflow_agent = workflow.as_agent(name="NoneTestWorkflow") - - _ = await workflow_agent.run("test", optional_param=None, other_param="value") - - assert len(agent.captured_kwargs) >= 1 - received = agent.captured_kwargs[0] - assert "optional_param" in received - assert received["optional_param"] is None - assert received["other_param"] == "value" + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs async def test_workflow_as_agent_kwargs_with_complex_nested_data() -> None: @@ -945,11 +633,11 @@ async def test_workflow_as_agent_kwargs_with_complex_nested_data() -> None: }, } - _ = await workflow_agent.run("test", complex_data=complex_data) + _ = await workflow_agent.run("test", function_invocation_kwargs=complex_data) assert len(agent.captured_kwargs) >= 1 received = agent.captured_kwargs[0] - assert received.get("complex_data") == complex_data + assert received.get("function_invocation_kwargs") == complex_data # endregion @@ -979,15 +667,13 @@ async def test_subworkflow_kwargs_propagation() -> None: outer_workflow = SequentialBuilder(participants=[subworkflow_executor]).build() # Define kwargs that should propagate to subworkflow - custom_data = {"api_key": "secret123", "endpoint": "https://api.example.com"} - user_token = {"user_name": "alice", "access_level": "admin"} + fi_kwargs = {"api_key": "secret123", "endpoint": "https://api.example.com"} # Run the outer workflow with kwargs async for event in outer_workflow.run( "test message for subworkflow", stream=True, - custom_data=custom_data, - user_token=user_token, + function_invocation_kwargs=fi_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -998,17 +684,8 @@ async def test_subworkflow_kwargs_propagation() -> None: received_kwargs = inner_agent.captured_kwargs[0] # Verify kwargs were propagated from parent workflow to subworkflow agent - assert "custom_data" in received_kwargs, ( - f"Subworkflow agent should receive 'custom_data' kwarg. Received keys: {list(received_kwargs.keys())}" - ) - assert "user_token" in received_kwargs, ( - f"Subworkflow agent should receive 'user_token' kwarg. Received keys: {list(received_kwargs.keys())}" - ) - assert received_kwargs.get("custom_data") == custom_data, ( - f"Expected custom_data={custom_data}, got {received_kwargs.get('custom_data')}" - ) - assert received_kwargs.get("user_token") == user_token, ( - f"Expected user_token={user_token}, got {received_kwargs.get('user_token')}" + assert received_kwargs.get("function_invocation_kwargs") == fi_kwargs, ( + f"Expected function_invocation_kwargs={fi_kwargs}, got {received_kwargs.get('function_invocation_kwargs')}" ) @@ -1043,11 +720,11 @@ async def read_kwargs(self, msgs: list[Message], ctx: WorkflowContext[list[Messa outer_workflow = SequentialBuilder(participants=[subworkflow_executor]).build() # Run with kwargs + fi_kwargs = {"my_custom_kwarg": "should_be_propagated", "another_kwarg": 42} async for event in outer_workflow.run( "test", stream=True, - my_custom_kwarg="should_be_propagated", - another_kwarg=42, + function_invocation_kwargs=fi_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -1057,11 +734,8 @@ async def read_kwargs(self, msgs: list[Message], ctx: WorkflowContext[list[Messa kwargs_in_subworkflow = captured_kwargs_from_state[0] - assert kwargs_in_subworkflow.get("my_custom_kwarg") == "should_be_propagated", ( - f"Expected 'my_custom_kwarg' in subworkflow got: {kwargs_in_subworkflow}" - ) - assert kwargs_in_subworkflow.get("another_kwarg") == 42, ( - f"Expected 'another_kwarg'=42 in subworkflow got: {kwargs_in_subworkflow}" + assert "function_invocation_kwargs" in kwargs_in_subworkflow, ( + f"Expected 'function_invocation_kwargs' in subworkflow state, got: {kwargs_in_subworkflow}" ) @@ -1093,7 +767,7 @@ async def test_nested_subworkflow_kwargs_propagation() -> None: async for event in outer_workflow.run( "deeply nested test", stream=True, - deep_kwarg="should_reach_inner", + function_invocation_kwargs={"deep_kwarg": "should_reach_inner"}, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -1102,7 +776,7 @@ async def test_nested_subworkflow_kwargs_propagation() -> None: assert len(inner_agent.captured_kwargs) >= 1, "Deeply nested agent should be invoked" received = inner_agent.captured_kwargs[0] - assert received.get("deep_kwarg") == "should_reach_inner", ( + assert received.get("function_invocation_kwargs") == {"deep_kwarg": "should_reach_inner"}, ( f"Deeply nested agent should receive 'deep_kwarg'. Got: {received}" ) @@ -1113,28 +787,8 @@ async def test_nested_subworkflow_kwargs_propagation() -> None: # region Per-Executor Invocation Kwargs Tests -async def test_legacy_kwargs_cannot_coexist_with_new_invocation_kwargs() -> None: - """Passing both legacy **kwargs and function_invocation_kwargs/client_invocation_kwargs must raise ValueError.""" - agent = _KwargsCapturingAgent(name="agent1") - workflow = SequentialBuilder(participants=[agent]).build() - - with pytest.raises(ValueError, match="Cannot provide both deprecated kwargs"): - await workflow.run( - "test", - function_invocation_kwargs={"tool_key": "val"}, - custom_legacy_kwarg="should_conflict", - ) - - with pytest.raises(ValueError, match="Cannot provide both deprecated kwargs"): - await workflow.run( - "test", - client_invocation_kwargs={"model": "gpt-4"}, - custom_legacy_kwarg="should_conflict", - ) - - -async def test_function_and_client_invocation_kwargs_together() -> None: - """Both function_invocation_kwargs and client_invocation_kwargs can be provided in the same call.""" +async def test_function_and_client_kwargs_together() -> None: + """Both function_invocation_kwargs and client_kwargs can be provided in the same call.""" agent1 = _KwargsCapturingAgent(name="agent1") agent2 = _KwargsCapturingAgent(name="agent2") workflow = SequentialBuilder(participants=[agent1, agent2]).build() @@ -1146,7 +800,7 @@ async def test_function_and_client_invocation_kwargs_together() -> None: "test", stream=True, function_invocation_kwargs=fi_kwargs, - client_invocation_kwargs=ci_kwargs, + client_kwargs=ci_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -1231,8 +885,8 @@ async def test_per_executor_kwargs_unmatched_agent_gets_none() -> None: assert agent2.captured_kwargs[0].get("function_invocation_kwargs") is None -async def test_global_client_invocation_kwargs_flow_to_all_agents() -> None: - """Global client_invocation_kwargs should be received by all agents.""" +async def test_global_client_kwargs_flow_to_all_agents() -> None: + """Global client_kwargs should be received by all agents.""" agent1 = _KwargsCapturingAgent(name="agent1") agent2 = _KwargsCapturingAgent(name="agent2") workflow = SequentialBuilder(participants=[agent1, agent2]).build() @@ -1242,7 +896,7 @@ async def test_global_client_invocation_kwargs_flow_to_all_agents() -> None: async for event in workflow.run( "test", stream=True, - client_invocation_kwargs=ci_kwargs, + client_kwargs=ci_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -1253,8 +907,8 @@ async def test_global_client_invocation_kwargs_flow_to_all_agents() -> None: assert agent2.captured_kwargs[0].get("client_kwargs") == ci_kwargs -async def test_per_executor_client_invocation_kwargs_routes_correctly() -> None: - """Per-executor client_invocation_kwargs should only be received by the targeted agent.""" +async def test_per_executor_client_kwargs_routes_correctly() -> None: + """Per-executor client_kwargs should only be received by the targeted agent.""" agent1 = _KwargsCapturingAgent(name="agent1") agent2 = _KwargsCapturingAgent(name="agent2") workflow = SequentialBuilder(participants=[agent1, agent2]).build() @@ -1267,7 +921,7 @@ async def test_per_executor_client_invocation_kwargs_routes_correctly() -> None: async for event in workflow.run( "test", stream=True, - client_invocation_kwargs=ci_kwargs, + client_kwargs=ci_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break diff --git a/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py b/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py index 9c12ac90fc..d35d409301 100644 --- a/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py +++ b/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py @@ -20,7 +20,7 @@ This sample demonstrates how to pass different kwargs to different agents in a workflow using per-agent targeting. When keys in function_invocation_kwargs (or -client_invocation_kwargs) match executor IDs (agent names by default), each agent +client_kwargs) match executor IDs (agent names by default), each agent receives only its own slice of the kwargs. Key Concepts: From 319b0954b166e843469d968e30393fce62f0e5dc Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 1 Apr 2026 12:48:23 -0700 Subject: [PATCH 7/9] Comments 2 --- .../_workflows/_agent_executor.py | 11 +++- .../tests/workflow/test_agent_executor.py | 16 +++++- .../tests/workflow/test_workflow_kwargs.py | 50 +++++++++++++++++++ 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 72c3a91b74..3740fc783d 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -461,7 +461,14 @@ def _resolve_executor_kwargs(self, resolved: dict[str, Any] | None) -> dict[str, """ if not isinstance(resolved, dict): return None - executor_kwargs = resolved.get(self.id) or resolved.get(GLOBAL_KWARGS_KEY) + # Use explicit key-presence checks so that an empty per-executor dict is + # honoured (e.g. to clear kwargs) instead of falling through to global. + if self.id in resolved: + executor_kwargs = resolved[self.id] + elif GLOBAL_KWARGS_KEY in resolved: + executor_kwargs = resolved[GLOBAL_KWARGS_KEY] + else: + return None if isinstance(executor_kwargs, dict): - return cast(dict[str, Any], executor_kwargs) or None + return dict(executor_kwargs) or None return None diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 4e32265a44..319233a11c 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -584,11 +584,11 @@ async def test_checkpoint_restore_works_without_context_mode_in_state() -> None: async def test_resolve_executor_kwargs_returns_global_kwargs() -> None: - """_resolve_executor_kwargs with __global__ key returns the global kwargs.""" + """_resolve_executor_kwargs with the global kwargs key returns the global kwargs.""" agent = _CountingAgent(id="a", name="A") executor = AgentExecutor(agent, id="exec_a") - resolved = {"__global__": {"tool_param": "value"}} + resolved = {GLOBAL_KWARGS_KEY: {"tool_param": "value"}} result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] assert result == {"tool_param": "value"} @@ -697,3 +697,15 @@ async def test_prepare_agent_run_args_per_executor_no_match() -> None: } fi_kwargs, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] assert fi_kwargs is None + + +async def test_resolve_executor_kwargs_empty_per_executor_does_not_fallback_to_global() -> None: + """An explicit empty per-executor dict should not fall through to global kwargs.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + # Per-executor entry for exec_a is empty, but global has values. + # The empty dict should be honoured (no fallback to global). + resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result is None diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 83927893cc..fa6af58fbb 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -972,4 +972,54 @@ async def test_resolve_invocation_kwargs_logs_global(caplog: "LogCaptureFixture" assert len(global_logs) >= 1 +async def test_empty_function_invocation_kwargs_clears_previous() -> None: + """Passing function_invocation_kwargs={} should clear previously stored kwargs on a new run.""" + agent = _KwargsCapturingAgent(name="clearing_agent") + workflow = SequentialBuilder(participants=[agent]).build() + + # First run: provide kwargs + await workflow.run( + "first", + function_invocation_kwargs={"key": "value"}, + ) + + assert len(agent.captured_kwargs) >= 1 + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == {"key": "value"} + + # Second run: pass empty dict to explicitly clear + await workflow.run( + "second", + function_invocation_kwargs={}, + ) + + # Agent should receive None because the empty dict resolves to an empty + # __global__ entry which is treated as "no kwargs" for each executor. + assert len(agent.captured_kwargs) >= 2 + assert agent.captured_kwargs[-1].get("function_invocation_kwargs") is None + + +async def test_empty_client_kwargs_clears_previous() -> None: + """Passing client_kwargs={} should clear previously stored kwargs on a new run.""" + agent = _KwargsCapturingAgent(name="clearing_agent") + workflow = SequentialBuilder(participants=[agent]).build() + + # First run: provide kwargs + await workflow.run( + "first", + client_kwargs={"temperature": 0.5}, + ) + + assert len(agent.captured_kwargs) >= 1 + assert agent.captured_kwargs[0].get("client_kwargs") == {"temperature": 0.5} + + # Second run: pass empty dict to explicitly clear + await workflow.run( + "second", + client_kwargs={}, + ) + + assert len(agent.captured_kwargs) >= 2 + assert agent.captured_kwargs[-1].get("client_kwargs") is None + + # endregion From 4400dda34d759b3298369a6c1498e1740007bb22 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 1 Apr 2026 13:03:41 -0700 Subject: [PATCH 8/9] Comments 3 --- .../core/agent_framework/_workflows/_agent.py | 29 +++++++++++-------- .../_workflows/_agent_executor.py | 5 ++-- .../_workflows/_workflow_executor.py | 15 ++++++++-- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index ff98befdab..af3eb5c40e 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -400,12 +400,6 @@ async def _run_core( Yields: WorkflowEvent objects from the workflow execution. """ - invocation_kwargs: dict[str, Any] = {} - if function_invocation_kwargs is not None: - invocation_kwargs["function_invocation_kwargs"] = function_invocation_kwargs - if client_kwargs is not None: - invocation_kwargs["client_kwargs"] = client_kwargs - # Determine the execution mode based on state. # The streaming flag controls the workflow's internal streaming mode, # which affects executor behavior (e.g. AgentExecutor emits different event @@ -414,11 +408,18 @@ async def _run_core( function_responses = self._process_pending_requests(input_messages) if streaming: async for event in self.workflow.run( - responses=function_responses, stream=True, **invocation_kwargs + responses=function_responses, + stream=True, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): yield event else: - for event in await self.workflow.run(responses=function_responses, **invocation_kwargs): + for event in await self.workflow.run( + responses=function_responses, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ): yield event elif checkpoint_id is not None: @@ -427,14 +428,16 @@ async def _run_core( stream=True, checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, - **invocation_kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): yield event else: for event in await self.workflow.run( checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, - **invocation_kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): yield event @@ -444,14 +447,16 @@ async def _run_core( message=input_messages, stream=True, checkpoint_storage=checkpoint_storage, - **invocation_kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): yield event else: for event in await self.workflow.run( message=input_messages, checkpoint_storage=checkpoint_storage, - **invocation_kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): yield event diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 3740fc783d..c6db6762f6 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -469,6 +469,5 @@ def _resolve_executor_kwargs(self, resolved: dict[str, Any] | None) -> dict[str, executor_kwargs = resolved[GLOBAL_KWARGS_KEY] else: return None - if isinstance(executor_kwargs, dict): - return dict(executor_kwargs) or None - return None + + return executor_kwargs diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index b4dfb6a7f8..afb6145251 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -391,15 +391,24 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) # The state stores resolved format (with __global__ wrapper for global kwargs). # Unwrap __global__ before passing to the subworkflow so it gets re-resolved # against the subworkflow's own executor IDs. - invocation_kwargs: dict[str, Any] = {} + fi_kwargs: dict[str, Any] | None = None + ci_kwargs: dict[str, Any] | None = None for key in ("function_invocation_kwargs", "client_kwargs"): resolved = parent_kwargs.get(key) if isinstance(resolved, dict): # Unwrap global sentinel; pass per-executor dicts as-is - invocation_kwargs[key] = resolved.get(GLOBAL_KWARGS_KEY, resolved) + unwrapped: dict[str, Any] = resolved.get(GLOBAL_KWARGS_KEY, resolved) # type: ignore + if key == "function_invocation_kwargs": + fi_kwargs = unwrapped # type: ignore + else: + ci_kwargs = unwrapped # type: ignore # Run the sub-workflow and collect all events, passing parent kwargs - result = await self.workflow.run(input_data, **invocation_kwargs) + result = await self.workflow.run( + input_data, + function_invocation_kwargs=fi_kwargs, # type: ignore + client_kwargs=ci_kwargs, # type: ignore + ) logger.debug( f"WorkflowExecutor {self.id} sub-workflow {self.workflow.id} " From 8a17c5b7579530b3160c51b9cb5a7b28dd68ed25 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 1 Apr 2026 13:13:29 -0700 Subject: [PATCH 9/9] Fix test and typing --- .../agent_framework/_workflows/_agent_executor.py | 11 ++++++++++- .../core/tests/workflow/test_agent_executor.py | 12 +----------- .../core/tests/workflow/test_workflow_kwargs.py | 4 ++-- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index c6db6762f6..986b086fab 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -470,4 +470,13 @@ def _resolve_executor_kwargs(self, resolved: dict[str, Any] | None) -> dict[str, else: return None - return executor_kwargs + if not isinstance(executor_kwargs, dict): + logger.warning( + "Executor %s expected a dict for its kwargs, but got %s. Ignoring.", + self.id, + type(executor_kwargs), # type: ignore + ) + + return None + + return executor_kwargs # type: ignore diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 319233a11c..f6b7eb7f31 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -633,16 +633,6 @@ async def test_resolve_executor_kwargs_prefers_executor_id_over_global() -> None assert result == {"specific": True} -async def test_resolve_executor_kwargs_returns_none_for_empty_dict_value() -> None: - """_resolve_executor_kwargs returns None when the matched value is an empty dict.""" - agent = _CountingAgent(id="a", name="A") - executor = AgentExecutor(agent, id="exec_a") - - resolved = {GLOBAL_KWARGS_KEY: {}} - result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] - assert result is None - - async def test_prepare_agent_run_args_extracts_function_invocation_kwargs() -> None: """_prepare_agent_run_args extracts function_invocation_kwargs from the state dict.""" agent = _CountingAgent(id="a", name="A") @@ -708,4 +698,4 @@ async def test_resolve_executor_kwargs_empty_per_executor_does_not_fallback_to_g # The empty dict should be honoured (no fallback to global). resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}} result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] - assert result is None + assert result == {} diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index fa6af58fbb..1b14d53375 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -995,7 +995,7 @@ async def test_empty_function_invocation_kwargs_clears_previous() -> None: # Agent should receive None because the empty dict resolves to an empty # __global__ entry which is treated as "no kwargs" for each executor. assert len(agent.captured_kwargs) >= 2 - assert agent.captured_kwargs[-1].get("function_invocation_kwargs") is None + assert agent.captured_kwargs[-1].get("function_invocation_kwargs") == {} async def test_empty_client_kwargs_clears_previous() -> None: @@ -1019,7 +1019,7 @@ async def test_empty_client_kwargs_clears_previous() -> None: ) assert len(agent.captured_kwargs) >= 2 - assert agent.captured_kwargs[-1].get("client_kwargs") is None + assert agent.captured_kwargs[-1].get("client_kwargs") == {} # endregion