diff --git a/python/packages/core/agent_framework/_evaluation.py b/python/packages/core/agent_framework/_evaluation.py index 92a694cc36..80bc1ecffb 100644 --- a/python/packages/core/agent_framework/_evaluation.py +++ b/python/packages/core/agent_framework/_evaluation.py @@ -96,6 +96,7 @@ def split_before_memory(conversation): # Fallback: split at last user message return EvalItem._split_last_turn_static(conversation) + item.split_messages(split=split_before_memory) """ @@ -468,10 +469,7 @@ def raise_for_status(self, msg: str | None = None) -> None: """ if not self.all_passed: errored = (self.result_counts or {}).get("errored", 0) - detail = msg or ( - f"Eval run {self.run_id} {self.status}: " - f"{self.passed} passed, {self.failed} failed." - ) + detail = msg or (f"Eval run {self.run_id} {self.status}: {self.passed} passed, {self.failed} failed.") if errored: detail += f" {errored} errored." if self.report_url: @@ -1188,8 +1186,7 @@ def _coerce_result(value: Any, check_name: str) -> CheckResult: score = float(d["score"]) except (TypeError, ValueError) as exc: raise TypeError( - f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value:" - f" {d['score']!r}" + f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value: {d['score']!r}" ) from exc # Honour an explicit 'passed' override; otherwise threshold-based. passed = bool(d["passed"]) if "passed" in d else score >= float(d.get("threshold", 0.5)) diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index bf615814b3..af3eb5c40e 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -6,7 +6,7 @@ import logging import sys import uuid -from collections.abc import AsyncIterable, Awaitable, Sequence +from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload @@ -152,7 +152,8 @@ def run( session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... @overload @@ -164,7 +165,8 @@ async def run( session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AgentResponse: ... def run( @@ -175,7 +177,8 @@ def run( session: AgentSession | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> ResponseStream[AgentResponseUpdate, AgentResponse] | Awaitable[AgentResponse]: """Get a response from the workflow agent. @@ -192,8 +195,12 @@ def run( checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id, used to load and restore the checkpoint. When provided without checkpoint_id, enables checkpointing for this run. - **kwargs: Additional keyword arguments passed through to underlying workflow - and tool functions. + function_invocation_kwargs: Keyword arguments forwarded to tool invocations in + subagents. Either a mapping of agent name/executor id to kwargs, or a flat + mapping of kwargs for all tool invocations. + client_kwargs: Keyword arguments forwarded to chat client calls in + subagents. Either a mapping of agent name/executor id to kwargs, or a flat + mapping of kwargs for all chat client calls. Returns: When stream=True: An AsyncIterable[AgentResponseUpdate] for streaming updates. @@ -208,10 +215,26 @@ def run( response_id = str(uuid.uuid4()) if stream: return ResponseStream( - self._run_stream_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs), + self._run_stream_impl( + messages, + response_id, + session, + checkpoint_id, + checkpoint_storage, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ), finalizer=AgentResponse.from_updates, ) - return self._run_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs) + return self._run_impl( + messages, + response_id, + session, + checkpoint_id, + checkpoint_storage, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ) async def _run_impl( self, @@ -220,7 +243,8 @@ async def _run_impl( session: AgentSession | None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AgentResponse: """Internal implementation of non-streaming execution. @@ -230,8 +254,8 @@ async def _run_impl( session: The agent session for conversation context. checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. - **kwargs: Additional keyword arguments passed through to the underlying - workflow and tool functions. + function_invocation_kwargs: Optional kwargs for tool invocations. + client_kwargs: Optional kwargs for chat client calls. Returns: An AgentResponse representing the workflow execution results. @@ -264,7 +288,12 @@ async def _run_impl( output_events: list[WorkflowEvent[Any]] = [] async for event in self._run_core( - session_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs + session_messages, + checkpoint_id, + checkpoint_storage, + streaming=False, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): if event.type == "output" or event.type == "request_info": output_events.append(event) @@ -285,7 +314,8 @@ async def _run_stream_impl( session: AgentSession | None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AsyncIterable[AgentResponseUpdate]: """Internal implementation of streaming execution. @@ -295,8 +325,8 @@ async def _run_stream_impl( session: The agent session for conversation context. checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. - **kwargs: Additional keyword arguments passed through to the underlying - workflow and tool functions. + function_invocation_kwargs: Optional kwargs for tool invocations. + client_kwargs: Optional kwargs for chat client calls. Yields: AgentResponseUpdate objects representing the workflow execution progress. @@ -329,7 +359,12 @@ async def _run_stream_impl( session_messages: list[Message] = session_context.get_messages(include_input=True) all_updates: list[AgentResponseUpdate] = [] async for event in self._run_core( - session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs + session_messages, + checkpoint_id, + checkpoint_storage, + streaming=True, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): updates = self._convert_workflow_event_to_agent_response_updates(response_id, event) for update in updates: @@ -349,7 +384,8 @@ async def _run_core( checkpoint_id: str | None, checkpoint_storage: CheckpointStorage | None, streaming: bool, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Core implementation that yields workflow events for both streaming and non-streaming modes. @@ -358,8 +394,8 @@ async def _run_core( checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. streaming: Whether to use streaming workflow methods. - **kwargs: Additional keyword arguments passed through to the underlying - workflow and tool functions. + function_invocation_kwargs: Optional kwargs for tool invocations. + client_kwargs: Optional kwargs for chat client calls. Yields: WorkflowEvent objects from the workflow execution. @@ -371,10 +407,19 @@ async def _run_core( if bool(self.pending_requests): function_responses = self._process_pending_requests(input_messages) if streaming: - async for event in self.workflow.run(responses=function_responses, stream=True, **kwargs): + async for event in self.workflow.run( + responses=function_responses, + stream=True, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ): yield event else: - for event in await self.workflow.run(responses=function_responses, **kwargs): + for event in await self.workflow.run( + responses=function_responses, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ): yield event elif checkpoint_id is not None: @@ -383,14 +428,16 @@ async def _run_core( stream=True, checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, - **kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): yield event else: for event in await self.workflow.run( checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, - **kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): yield event @@ -400,14 +447,16 @@ async def _run_core( message=input_messages, stream=True, checkpoint_storage=checkpoint_storage, - **kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): yield event else: for event in await self.workflow.run( message=input_messages, checkpoint_storage=checkpoint_storage, - **kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): yield event diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 9653091f75..986b086fab 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -2,7 +2,7 @@ import logging import sys -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import Awaitable, Callable from dataclasses import dataclass from typing import Any, Literal, cast @@ -14,7 +14,7 @@ from .._sessions import AgentSession from .._types import AgentResponse, AgentResponseUpdate, Message, ResponseStream from ._agent_utils import resolve_agent_id -from ._const import WORKFLOW_RUN_KWARGS_KEY +from ._const import GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY from ._executor import Executor, handler from ._message_utils import normalize_messages_input from ._request_info_mixin import response_handler @@ -335,15 +335,17 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR Returns: The complete AgentResponse, or None if waiting for user input. """ - run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})) + function_invocation_kwargs, client_kwargs = self._prepare_agent_run_args( + ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) + ) run_agent = cast(Callable[..., Awaitable[AgentResponse[Any]]], self._agent.run) response = await run_agent( self._cache, stream=False, session=self._session, - options=options, - **run_kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) await ctx.yield_output(response) @@ -365,7 +367,9 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp Returns: The complete AgentResponse, or None if waiting for user input. """ - run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})) + function_invocation_kwargs, client_kwargs = self._prepare_agent_run_args( + ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) + ) updates: list[AgentResponseUpdate] = [] streamed_user_input_requests: list[Content] = [] @@ -374,8 +378,8 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp self._cache, stream=True, session=self._session, - options=options, - **run_kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) async for update in stream: updates.append(update) @@ -421,74 +425,58 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp return response - # Parameters that are explicitly passed to agent.run() by AgentExecutor - # and must not appear in **run_kwargs to avoid TypeError from duplicate values. - _RESERVED_RUN_PARAMS: frozenset[str] = frozenset({"session", "stream", "messages"}) + def _prepare_agent_run_args( + self, + raw_run_kwargs: dict[str, Any], + ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: + """Prepare function_invocation_kwargs and client_kwargs for agent.run(). + + Extracts ``function_invocation_kwargs`` and ``client_kwargs`` from the + workflow state dict, resolving per-executor entries using ``self.id``. The + ``__global__`` sentinel key (set by ``Workflow._resolve_invocation_kwargs``) denotes + global kwargs that apply to all executors. Per-executor dicts use executor IDs as + keys; this executor extracts only its own entry. + + Returns: + A 2-tuple of (function_invocation_kwargs, client_kwargs). + """ + fi_resolved = raw_run_kwargs.get("function_invocation_kwargs") + ci_resolved = raw_run_kwargs.get("client_kwargs") + + function_invocation_kwargs = self._resolve_executor_kwargs(fi_resolved) + client_kwargs = self._resolve_executor_kwargs(ci_resolved) + + return function_invocation_kwargs, client_kwargs - @staticmethod - def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]: - """Prepare kwargs and options for agent.run(), avoiding duplicate option passing. + def _resolve_executor_kwargs(self, resolved: dict[str, Any] | None) -> dict[str, Any] | None: + """Extract this executor's kwargs from a resolved invocation kwargs dict. - Workflow-level kwargs are propagated to tool calls through - `options.additional_function_arguments`. If workflow kwargs include an - `options` key, merge it into the final options object and remove it from - kwargs before spreading `**run_kwargs`. + Args: + resolved: The resolved dict produced by ``Workflow._resolve_invocation_kwargs``, + containing either a ``__global__`` key (global kwargs) or executor-ID keys + (per-executor kwargs). May also be ``None``. - Reserved parameters (session, stream, messages) that are explicitly - managed by AgentExecutor are stripped from run_kwargs to prevent - ``TypeError: got multiple values for keyword argument`` collisions. + Returns: + The kwargs for this executor, or ``None`` if not applicable. """ - run_kwargs = dict(raw_run_kwargs) - - # Strip reserved params that AgentExecutor passes explicitly to agent.run(). - for key in AgentExecutor._RESERVED_RUN_PARAMS: - if key in run_kwargs: - logger.warning( - "Workflow kwarg '%s' is reserved by AgentExecutor and will be ignored. " - "Remove it from workflow.run() kwargs to silence this warning.", - key, - ) - run_kwargs.pop(key) - - options_from_workflow = run_kwargs.pop("options", None) - workflow_additional_args = run_kwargs.pop("additional_function_arguments", None) - - options: dict[str, Any] = {} - if options_from_workflow is not None: - if isinstance(options_from_workflow, Mapping): - options_from_workflow_map = cast(Mapping[str, Any], options_from_workflow) - for key, value in options_from_workflow_map.items(): - options[key] = value - else: - logger.warning( - "Ignoring non-mapping workflow 'options' kwarg of type %s for AgentExecutor %s.", - type(options_from_workflow).__name__, - AgentExecutor.__name__, - ) - - existing_additional_args = options.get("additional_function_arguments") - additional_args: dict[str, Any] - if isinstance(existing_additional_args, Mapping): - existing_additional_args_map = cast(Mapping[str, Any], existing_additional_args) - additional_args = {key: value for key, value in existing_additional_args_map.items()} + if not isinstance(resolved, dict): + return None + # Use explicit key-presence checks so that an empty per-executor dict is + # honoured (e.g. to clear kwargs) instead of falling through to global. + if self.id in resolved: + executor_kwargs = resolved[self.id] + elif GLOBAL_KWARGS_KEY in resolved: + executor_kwargs = resolved[GLOBAL_KWARGS_KEY] else: - additional_args = {} - - if workflow_additional_args is not None: - if isinstance(workflow_additional_args, Mapping): - workflow_additional_args_map = cast(Mapping[str, Any], workflow_additional_args) - additional_args.update({key: value for key, value in workflow_additional_args_map.items()}) - else: - logger.warning( - "Ignoring non-mapping workflow 'additional_function_arguments' kwarg of type %s for AgentExecutor %s.", # noqa: E501 - type(workflow_additional_args).__name__, - AgentExecutor.__name__, - ) - - if run_kwargs: - additional_args.update(run_kwargs) - - if additional_args: - options["additional_function_arguments"] = additional_args - - return run_kwargs, options or None + return None + + if not isinstance(executor_kwargs, dict): + logger.warning( + "Executor %s expected a dict for its kwargs, but got %s. Ignoring.", + self.id, + type(executor_kwargs), # type: ignore + ) + + return None + + return executor_kwargs # type: ignore diff --git a/python/packages/core/agent_framework/_workflows/_const.py b/python/packages/core/agent_framework/_workflows/_const.py index a8416af790..e83025bbdc 100644 --- a/python/packages/core/agent_framework/_workflows/_const.py +++ b/python/packages/core/agent_framework/_workflows/_const.py @@ -14,6 +14,10 @@ # to pass kwargs from workflow.run() through to agent.run() and @tool functions. WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs" +# Sentinel key used in resolved invocation kwargs dicts to denote global kwargs +# that apply to all executors (as opposed to per-executor keyed entries). +GLOBAL_KWARGS_KEY = "__global__" + def INTERNAL_SOURCE_ID(executor_id: str) -> str: """Generate an internal source ID for a given executor.""" diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index cf030bf7b0..58050eece9 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -10,14 +10,14 @@ import logging import types import uuid -from collections.abc import AsyncIterable, Awaitable, Callable, Sequence +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from typing import Any, Literal, overload from .._types import ResponseStream from ..observability import OtelAttr, capture_exception, create_workflow_span from ._agent import WorkflowAgent from ._checkpoint import CheckpointStorage -from ._const import DEFAULT_MAX_ITERATIONS, WORKFLOW_RUN_KWARGS_KEY +from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY from ._edge import ( EdgeGroup, FanOutEdgeGroup, @@ -180,7 +180,6 @@ def __init__( description: str | None = None, max_iterations: int = DEFAULT_MAX_ITERATIONS, output_executors: list[str] | None = None, - **kwargs: Any, ): """Initialize the workflow with a list of edges. @@ -198,7 +197,6 @@ def __init__( WorkflowBuilder, this will be the description of the builder. output_executors: Optional list of executor IDs whose outputs will be considered workflow outputs. If None or empty, all executor outputs are treated as workflow outputs. - kwargs: Additional keyword arguments. Unused in this implementation. """ self.edge_groups = list(edge_groups) self.executors = dict(executors) @@ -300,7 +298,8 @@ async def _run_workflow_with_tracing( initial_executor_fn: Callable[[], Awaitable[None]] | None = None, reset_context: bool = True, streaming: bool = False, - run_kwargs: dict[str, Any] | None = None, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Private method to run workflow with proper tracing. @@ -311,7 +310,10 @@ async def _run_workflow_with_tracing( initial_executor_fn: Optional function to execute initial executor reset_context: Whether to reset the context for a new run streaming: Whether to enable streaming mode for agents - run_kwargs: Optional kwargs to store in State for agent invocations + function_invocation_kwargs: Optional kwargs to store in State for function + invocations in subagents + client_kwargs: Optional kwargs to store in State for chat client + invocations in subagents Yields: WorkflowEvent: The events generated during the workflow execution. @@ -350,8 +352,17 @@ async def _run_workflow_with_tracing( # Only overwrite when new kwargs are explicitly provided or state was # just cleared (fresh run). On continuation (reset_context=False) with # no new kwargs, preserve the kwargs from the original run. - if run_kwargs is not None: - self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs) + if function_invocation_kwargs is not None or client_kwargs is not None: + combined_kwargs: dict[str, Any] = {} + if function_invocation_kwargs is not None: + combined_kwargs["function_invocation_kwargs"] = self._resolve_invocation_kwargs( + function_invocation_kwargs, "function_invocation_kwargs" + ) + if client_kwargs is not None: + combined_kwargs["client_kwargs"] = self._resolve_invocation_kwargs( + client_kwargs, "client_kwargs" + ) + self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs) elif reset_context: self._state.set(WORKFLOW_RUN_KWARGS_KEY, {}) self._state.commit() # Commit immediately so kwargs are available @@ -459,10 +470,11 @@ def run( message: Any | None = None, *, stream: Literal[True], - responses: dict[str, Any] | None = None, + responses: Mapping[str, Any] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, ) -> ResponseStream[WorkflowEvent, WorkflowRunResult]: ... @overload @@ -471,11 +483,12 @@ def run( message: Any | None = None, *, stream: Literal[False] = ..., - responses: dict[str, Any] | None = None, + responses: Mapping[str, Any] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, ) -> Awaitable[WorkflowRunResult]: ... def run( @@ -483,11 +496,12 @@ def run( message: Any | None = None, *, stream: bool = False, - responses: dict[str, Any] | None = None, + responses: Mapping[str, Any] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, include_status_events: bool = False, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> ResponseStream[WorkflowEvent, WorkflowRunResult] | Awaitable[WorkflowRunResult]: """Run the workflow, optionally streaming events. @@ -509,7 +523,12 @@ def run( (restore then send responses). checkpoint_storage: Runtime checkpoint storage. include_status_events: Whether to include status events (non-streaming only). - **kwargs: Additional keyword arguments to pass through to agent invocations. + function_invocation_kwargs: Keyword arguments forwarded to tool invocations in + subagents. Either a mapping for agent name or agent executor id to kwargs, + or a flat mapping of kwargs for all tool invocations. + client_kwargs: Keyword arguments forwarded to chat client calls in + subagents. Either a mapping for agent name or agent executor id to kwargs, + or a flat mapping of kwargs for all chat client calls. Returns: When stream=True: A ResponseStream[WorkflowEvent, WorkflowRunResult] for @@ -530,7 +549,8 @@ def run( checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, streaming=stream, - **kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ), finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events), cleanup_hooks=[ @@ -546,11 +566,12 @@ async def _run_core( self, message: Any | None = None, *, - responses: dict[str, Any] | None = None, + responses: Mapping[str, Any] | None = None, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, streaming: bool = False, - **kwargs: Any, + function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Single core execution path for both streaming and non-streaming modes. @@ -569,11 +590,8 @@ async def _run_core( initial_executor_fn=initial_executor_fn, reset_context=reset_context, streaming=streaming, - # Empty **kwargs (no caller-provided kwargs) is collapsed to None so that - # continuation calls without explicit kwargs preserve the original run's kwargs. - # A non-empty kwargs dict (even one with empty values like {"key": {}}) - # is passed through and will overwrite stored kwargs. - run_kwargs=kwargs if kwargs else None, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ): if event.type == "output" and not self._should_yield_output_event(event): continue @@ -624,7 +642,7 @@ def _finalize_events( @staticmethod def _validate_run_params( message: Any | None, - responses: dict[str, Any] | None, + responses: Mapping[str, Any] | None, checkpoint_id: str | None, ) -> None: """Validate parameter combinations for run(). @@ -650,7 +668,7 @@ def _validate_run_params( def _resolve_execution_mode( self, message: Any | None, - responses: dict[str, Any] | None, + responses: Mapping[str, Any] | None, checkpoint_id: str | None, checkpoint_storage: CheckpointStorage | None, ) -> tuple[Callable[[], Awaitable[None]], bool]: @@ -680,7 +698,7 @@ async def _restore_and_send_responses( self, checkpoint_id: str, checkpoint_storage: CheckpointStorage | None, - responses: dict[str, Any], + responses: Mapping[str, Any], ) -> None: """Restore from a checkpoint then send responses to pending requests. @@ -700,7 +718,7 @@ async def _restore_and_send_responses( await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage) await self._send_responses_internal(responses) - async def _send_responses_internal(self, responses: dict[str, Any]) -> None: + async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None: """Internal method to validate and send responses to the executors.""" pending_requests = await self._runner_context.get_pending_request_info_events() if not pending_requests: @@ -739,6 +757,44 @@ def _get_executor_by_id(self, executor_id: str) -> Executor: raise ValueError(f"Executor with ID {executor_id} not found.") return self.executors[executor_id] + def _resolve_invocation_kwargs( + self, + kwargs: Mapping[str, Any], + param_name: str, + ) -> dict[str, Any]: + """Resolve invocation kwargs into a normalized per-executor or global format. + + Detects whether the provided kwargs dict uses per-executor targeting by checking + if any top-level key matches a known executor ID in the workflow. If at least one + key matches, all entries are treated as per-executor. Otherwise the dict is treated + as global kwargs that apply to every executor. + + Args: + kwargs: The raw invocation kwargs from the caller. + param_name: The parameter name (for logging), e.g. ``"function_invocation_kwargs"``. + + Returns: + A dict with either: + - ``{"__global__": }`` for global kwargs, or + - The original dict unchanged for per-executor kwargs. + """ + executor_ids = set(self.executors.keys()) + matched_ids = kwargs.keys() & executor_ids + if matched_ids: + logger.info( + "Detected per-executor %s: executor ID(s) %s found in keys. " + "All entries will be treated as per-executor.", + param_name, + matched_ids, + ) + return dict(kwargs) + + logger.info( + "No executor IDs found in %s keys; treating as global kwargs for all executors.", + param_name, + ) + return {GLOBAL_KWARGS_KEY: dict(kwargs)} + def _should_yield_output_event(self, event: WorkflowEvent[Any]) -> bool: """Determine if an output event should be yielded as a workflow output. diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index e9e4196bfd..afb6145251 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -12,7 +12,7 @@ from ._workflow import Workflow from ._checkpoint_encoding import decode_checkpoint_value -from ._const import WORKFLOW_RUN_KWARGS_KEY +from ._const import GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY from ._events import ( WorkflowEvent, WorkflowRunState, @@ -387,8 +387,28 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) # Get kwargs from parent workflow's State to propagate to subworkflow parent_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) + # Extract invocation kwargs recognised by Workflow.run() + # The state stores resolved format (with __global__ wrapper for global kwargs). + # Unwrap __global__ before passing to the subworkflow so it gets re-resolved + # against the subworkflow's own executor IDs. + fi_kwargs: dict[str, Any] | None = None + ci_kwargs: dict[str, Any] | None = None + for key in ("function_invocation_kwargs", "client_kwargs"): + resolved = parent_kwargs.get(key) + if isinstance(resolved, dict): + # Unwrap global sentinel; pass per-executor dicts as-is + unwrapped: dict[str, Any] = resolved.get(GLOBAL_KWARGS_KEY, resolved) # type: ignore + if key == "function_invocation_kwargs": + fi_kwargs = unwrapped # type: ignore + else: + ci_kwargs = unwrapped # type: ignore + # Run the sub-workflow and collect all events, passing parent kwargs - result = await self.workflow.run(input_data, **parent_kwargs) + result = await self.workflow.run( + input_data, + function_invocation_kwargs=fi_kwargs, # type: ignore + client_kwargs=ci_kwargs, # type: ignore + ) logger.debug( f"WorkflowExecutor {self.id} sub-workflow {self.workflow.id} " diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 6298a8963d..f6b7eb7f31 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -1,8 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. -import logging from collections.abc import AsyncIterable, Awaitable -from typing import TYPE_CHECKING, Any, Literal, overload +from typing import Any, Literal, overload import pytest @@ -22,9 +21,7 @@ ) from agent_framework._workflows._agent_executor import AgentExecutorResponse from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage - -if TYPE_CHECKING: - from _pytest.logging import LogCaptureFixture +from agent_framework._workflows._const import GLOBAL_KWARGS_KEY class _CountingAgent(BaseAgent): @@ -309,87 +306,28 @@ async def test_agent_executor_save_and_restore_state_directly() -> None: assert restored_session.session_id == session.session_id -async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None: - """Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295).""" - agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent") - executor = AgentExecutor(agent, id="session_kwarg_exec") - workflow = WorkflowBuilder(start_executor=executor).build() - - # This previously raised: TypeError: run() got multiple values for keyword argument 'session' - result = await workflow.run("hello", session="user-supplied-value") - assert result is not None - assert agent.call_count == 1 +async def test_prepare_agent_run_args_extracts_invocation_kwargs() -> None: + """_prepare_agent_run_args extracts function_invocation_kwargs and client_kwargs.""" + agent = _CountingAgent(id="test_agent", name="TestAgent") + executor = AgentExecutor(agent, id="test_exec") - -async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -> None: - """Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError.""" - agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent") - executor = AgentExecutor(agent, id="stream_kwarg_exec") - workflow = WorkflowBuilder(start_executor=executor).build() - - # stream=True at workflow level triggers streaming mode (returns async iterable) - events: list[WorkflowEvent] = [] - async for event in workflow.run("hello", stream=True): - events.append(event) - assert len(events) > 0 - assert agent.call_count == 1 - - -@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"]) -async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None: - """_prepare_agent_run_args must remove reserved kwargs and log a warning.""" raw: dict[str, Any] = { - reserved_kwarg: "should-be-stripped", - "custom_key": "keep-me", + "function_invocation_kwargs": {"__global__": {"key": "fi_val"}}, + "client_kwargs": {"__global__": {"key": "ci_val"}}, } + fi_kwargs, ci_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs == {"key": "fi_val"} + assert ci_kwargs == {"key": "ci_val"} - with caplog.at_level(logging.WARNING): - run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] - - assert reserved_kwarg not in run_kwargs - assert "custom_key" in run_kwargs - assert options is not None - assert options["additional_function_arguments"]["custom_key"] == "keep-me" - assert any(reserved_kwarg in record.message for record in caplog.records) +async def test_prepare_agent_run_args_returns_none_when_no_kwargs() -> None: + """_prepare_agent_run_args returns None for both when raw dict has no invocation kwargs.""" + agent = _CountingAgent(id="test_agent", name="TestAgent") + executor = AgentExecutor(agent, id="test_exec") -async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None: - """Non-reserved workflow kwargs should pass through unchanged.""" - raw: dict[str, Any] = {"custom_param": "value", "another": 42} - run_kwargs, _options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] - assert run_kwargs["custom_param"] == "value" - assert run_kwargs["another"] == 42 - - -async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once( - caplog: "LogCaptureFixture", -) -> None: - """All reserved kwargs should be stripped when supplied together, each emitting a warning.""" - raw: dict[str, Any] = {"session": "x", "stream": True, "messages": [], "custom": 1} - - with caplog.at_level(logging.WARNING): - run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] - - assert "session" not in run_kwargs - assert "stream" not in run_kwargs - assert "messages" not in run_kwargs - assert run_kwargs["custom"] == 1 - assert options is not None - assert options["additional_function_arguments"]["custom"] == 1 - - warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()} - assert warned_keys == {"session", "stream", "messages"} - - -async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None: - """Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError.""" - agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent") - executor = AgentExecutor(agent, id="messages_kwarg_exec") - workflow = WorkflowBuilder(start_executor=executor).build() - - result = await workflow.run("hello", messages=["stale"]) - assert result is not None - assert agent.call_count == 1 + fi_kwargs, ci_kwargs = executor._prepare_agent_run_args({}) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs is None + assert ci_kwargs is None class _NonCopyableRaw: @@ -638,3 +576,126 @@ async def test_checkpoint_restore_works_without_context_mode_in_state() -> None: assert cache[0].text == "cached msg" # context_mode should remain as configured in the constructor, not changed by restore assert executor._context_mode == "last_agent" # pyright: ignore[reportPrivateUsage] + + +# --------------------------------------------------------------------------- +# Per-executor kwargs resolution tests +# --------------------------------------------------------------------------- + + +async def test_resolve_executor_kwargs_returns_global_kwargs() -> None: + """_resolve_executor_kwargs with the global kwargs key returns the global kwargs.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + resolved = {GLOBAL_KWARGS_KEY: {"tool_param": "value"}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result == {"tool_param": "value"} + + +async def test_resolve_executor_kwargs_returns_per_executor_kwargs() -> None: + """_resolve_executor_kwargs with matching executor ID returns that executor's kwargs.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + resolved = {"exec_a": {"my_param": 42}, "exec_b": {"other_param": 99}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result == {"my_param": 42} + + +async def test_resolve_executor_kwargs_returns_none_for_unmatched_per_executor() -> None: + """_resolve_executor_kwargs returns None when per-executor dict has no matching ID.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_c") + + resolved = {"exec_a": {"my_param": 42}, "exec_b": {"other_param": 99}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result is None + + +async def test_resolve_executor_kwargs_returns_none_for_none_input() -> None: + """_resolve_executor_kwargs returns None when input is None.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + result = executor._resolve_executor_kwargs(None) # pyright: ignore[reportPrivateUsage] + assert result is None + + +async def test_resolve_executor_kwargs_prefers_executor_id_over_global() -> None: + """_resolve_executor_kwargs prefers executor-specific entry over __global__.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + # Dict has both a per-executor entry and a global entry + resolved = {"exec_a": {"specific": True}, GLOBAL_KWARGS_KEY: {"global": True}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result == {"specific": True} + + +async def test_prepare_agent_run_args_extracts_function_invocation_kwargs() -> None: + """_prepare_agent_run_args extracts function_invocation_kwargs from the state dict.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + raw: dict[str, Any] = { + "function_invocation_kwargs": {GLOBAL_KWARGS_KEY: {"tool_key": "tool_val"}}, + } + fi_kwargs, client_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs == {"tool_key": "tool_val"} + assert client_kwargs is None + + +async def test_prepare_agent_run_args_extracts_client_kwargs() -> None: + """_prepare_agent_run_args extracts client_kwargs from the state dict.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + raw: dict[str, Any] = { + "client_kwargs": {GLOBAL_KWARGS_KEY: {"model": "gpt-4"}}, + } + fi_kwargs, client_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs is None + assert client_kwargs == {"model": "gpt-4"} + + +async def test_prepare_agent_run_args_per_executor_resolution() -> None: + """_prepare_agent_run_args resolves per-executor function_invocation_kwargs using self.id.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + raw: dict[str, Any] = { + "function_invocation_kwargs": { + "exec_a": {"my_tool_key": "my_val"}, + "exec_b": {"other_tool_key": "other_val"}, + }, + } + fi_kwargs, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs == {"my_tool_key": "my_val"} + + +async def test_prepare_agent_run_args_per_executor_no_match() -> None: + """_prepare_agent_run_args returns None for function_invocation_kwargs when executor ID not found.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_c") + + raw: dict[str, Any] = { + "function_invocation_kwargs": { + "exec_a": {"my_tool_key": "my_val"}, + "exec_b": {"other_tool_key": "other_val"}, + }, + } + fi_kwargs, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] + assert fi_kwargs is None + + +async def test_resolve_executor_kwargs_empty_per_executor_does_not_fallback_to_global() -> None: + """An explicit empty per-executor dict should not fall through to global kwargs.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, id="exec_a") + + # Per-executor entry for exec_a is empty, but global has values. + # The empty dict should be honoured (no fallback to global). + resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}} + result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] + assert result == {} diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index d315f75f85..1b14d53375 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import AsyncIterable, Awaitable -from typing import Annotated, Any, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, overload import pytest @@ -15,7 +15,6 @@ Message, ResponseStream, WorkflowRunState, - tool, ) from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY from agent_framework.orchestrations import ( @@ -26,22 +25,11 @@ SequentialBuilder, ) -# Track kwargs received by tools during test execution -_received_kwargs: list[dict[str, Any]] = [] - - -@tool(approval_mode="never_require") -def tool_with_kwargs( - action: Annotated[str, "The action to perform"], - **kwargs: Any, -) -> str: - """A test tool that captures kwargs for verification.""" - _received_kwargs.append(dict(kwargs)) - custom_data = kwargs.get("custom_data", {}) - user_token = kwargs.get("user_token", {}) - return f"Executed {action} with custom_data={custom_data}, user={user_token.get('user_name', 'unknown')}" +if TYPE_CHECKING: + from _pytest.logging import LogCaptureFixture +# Track kwargs received by tools during test execution class _KwargsCapturingAgent(BaseAgent): """Test agent that captures kwargs passed to run.""" @@ -92,76 +80,20 @@ async def _run() -> AgentResponse: return _run() -class _OptionsAwareAgent(BaseAgent): - """Test agent that captures explicit `options` and kwargs passed to run().""" - - captured_options: list[dict[str, Any] | None] - captured_kwargs: list[dict[str, Any]] - - def __init__(self, name: str = "options_agent") -> None: - super().__init__(name=name, description="Test agent for options capture") - self.captured_options = [] - self.captured_kwargs = [] - - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[False] = ..., - session: AgentSession | None = ..., - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]]: ... - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[True], - session: AgentSession | None = ..., - **kwargs: Any, - ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - - def run( - self, - messages: AgentRunInputs | None = None, - *, - stream: bool = False, - session: AgentSession | None = None, - options: dict[str, Any] | None = None, - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: - self.captured_options.append(dict(options) if options is not None else None) - self.captured_kwargs.append(dict(kwargs)) - if stream: - - async def _stream() -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} response")]) - - return ResponseStream(_stream(), finalizer=AgentResponse.from_updates) - - async def _run() -> AgentResponse: - return AgentResponse(messages=[Message("assistant", [f"{self.name} response"])]) - - return _run() - - # region Sequential Builder Tests async def test_sequential_kwargs_flow_to_agent() -> None: - """Test that kwargs passed to SequentialBuilder workflow flow through to agent.""" + """Test that function_invocation_kwargs passed to SequentialBuilder workflow flow through to agent.""" agent = _KwargsCapturingAgent(name="seq_agent") workflow = SequentialBuilder(participants=[agent]).build() - custom_data = {"endpoint": "https://api.example.com", "version": "v1"} - user_token = {"user_name": "alice", "access_level": "admin"} + fi_kwargs = {"endpoint": "https://api.example.com", "version": "v1"} async for event in workflow.run( "test message", stream=True, - custom_data=custom_data, - user_token=user_token, + function_invocation_kwargs=fi_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -169,140 +101,53 @@ async def test_sequential_kwargs_flow_to_agent() -> None: # Verify agent received kwargs assert len(agent.captured_kwargs) >= 1, "Agent should have been invoked at least once" received = agent.captured_kwargs[0] - assert "custom_data" in received, "Agent should receive custom_data kwarg" - assert "user_token" in received, "Agent should receive user_token kwarg" - assert received["custom_data"] == custom_data - assert received["user_token"] == user_token + assert received.get("function_invocation_kwargs") == fi_kwargs async def test_sequential_kwargs_flow_to_multiple_agents() -> None: - """Test that kwargs flow to all agents in a sequential workflow.""" + """Test that function_invocation_kwargs flow to all agents in a sequential workflow.""" agent1 = _KwargsCapturingAgent(name="agent1") agent2 = _KwargsCapturingAgent(name="agent2") workflow = SequentialBuilder(participants=[agent1, agent2]).build() - custom_data = {"key": "value"} + fi_kwargs = {"key": "value"} - async for event in workflow.run("test", custom_data=custom_data, stream=True): + async for event in workflow.run("test", function_invocation_kwargs=fi_kwargs, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break # Both agents should have received kwargs assert len(agent1.captured_kwargs) >= 1, "First agent should be invoked" assert len(agent2.captured_kwargs) >= 1, "Second agent should be invoked" - assert agent1.captured_kwargs[0].get("custom_data") == custom_data - assert agent2.captured_kwargs[0].get("custom_data") == custom_data + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs async def test_sequential_run_kwargs_flow() -> None: - """Test that kwargs flow through workflow.run() (non-streaming).""" + """Test that function_invocation_kwargs flow through workflow.run() (non-streaming).""" agent = _KwargsCapturingAgent(name="run_agent") workflow = SequentialBuilder(participants=[agent]).build() - _ = await workflow.run("test message", custom_data={"test": True}) - - assert len(agent.captured_kwargs) >= 1 - assert agent.captured_kwargs[0].get("custom_data") == {"test": True} - - -async def test_sequential_run_options_does_not_conflict_with_agent_options() -> None: - """Test workflow.run(options=...) does not conflict with Agent.run(options=...).""" - agent = _OptionsAwareAgent(name="options_agent") - workflow = SequentialBuilder(participants=[agent]).build() - - custom_data = {"session_id": "abc123"} - user_token = {"user_name": "alice"} - provided_options = { - "store": False, - "additional_function_arguments": {"source": "workflow-options"}, - } - - async for event in workflow.run( - "test message", - stream=True, - options=provided_options, - custom_data=custom_data, - user_token=user_token, - ): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - assert len(agent.captured_options) >= 1 - captured_options: dict[str, Any] | None = agent.captured_options[0] - assert captured_options is not None - assert captured_options.get("store") is False - - additional_args: Any = captured_options.get("additional_function_arguments") - assert isinstance(additional_args, dict) - assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType] + _ = await workflow.run("test message", function_invocation_kwargs={"test": True}) - # "options" should be passed once via the dedicated options parameter, - # not duplicated in **kwargs. assert len(agent.captured_kwargs) >= 1 - captured_kwargs = agent.captured_kwargs[0] - assert "options" not in captured_kwargs - assert captured_kwargs.get("custom_data") == custom_data - assert captured_kwargs.get("user_token") == user_token + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == {"test": True} -async def test_sequential_run_additional_function_arguments_flattened() -> None: - """Test workflow.run(additional_function_arguments=...) maps directly to tool kwargs.""" - agent = _OptionsAwareAgent(name="options_agent") +async def test_sequential_run_non_streaming_kwargs_flow() -> None: + """Test workflow.run(function_invocation_kwargs=...) non-streaming path.""" + agent = _KwargsCapturingAgent(name="options_agent") workflow = SequentialBuilder(participants=[agent]).build() - custom_data = {"session_id": "abc123"} - user_token = {"user_name": "alice"} + fi_kwargs = {"session_id": "abc123"} - async for event in workflow.run( + _ = await workflow.run( "test message", - stream=True, - additional_function_arguments={"custom_data": custom_data, "user_token": user_token}, - ): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - assert len(agent.captured_options) >= 1 - captured_options: dict[str, Any] | None = agent.captured_options[0] - assert captured_options is not None - - additional_args: Any = captured_options.get("additional_function_arguments") - assert isinstance(additional_args, dict) - assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType] - assert "additional_function_arguments" not in additional_args + function_invocation_kwargs=fi_kwargs, + ) assert len(agent.captured_kwargs) >= 1 - captured_kwargs = agent.captured_kwargs[0] - assert "additional_function_arguments" not in captured_kwargs - - -async def test_sequential_run_additional_function_arguments_merges_with_options() -> None: - """Test workflow additional_function_arguments merges with workflow options.""" - agent = _OptionsAwareAgent(name="options_agent") - workflow = SequentialBuilder(participants=[agent]).build() - - async for event in workflow.run( - "test message", - stream=True, - options={"additional_function_arguments": {"source": "workflow-options"}}, - additional_function_arguments={"custom_data": {"session_id": "abc123"}}, - user_token={"user_name": "alice"}, - ): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - assert len(agent.captured_options) >= 1 - captured_options: dict[str, Any] | None = agent.captured_options[0] - assert captured_options is not None - - additional_args: Any = captured_options.get("additional_function_arguments") - assert isinstance(additional_args, dict) - assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("custom_data") == {"session_id": "abc123"} # pyright: ignore[reportUnknownMemberType] - assert additional_args.get("user_token") == {"user_name": "alice"} # pyright: ignore[reportUnknownMemberType] - assert "additional_function_arguments" not in additional_args + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs # endregion @@ -312,19 +157,17 @@ async def test_sequential_run_additional_function_arguments_merges_with_options( async def test_concurrent_kwargs_flow_to_agents() -> None: - """Test that kwargs flow to all agents in a concurrent workflow.""" + """Test that function_invocation_kwargs flow to all agents in a concurrent workflow.""" agent1 = _KwargsCapturingAgent(name="concurrent1") agent2 = _KwargsCapturingAgent(name="concurrent2") workflow = ConcurrentBuilder(participants=[agent1, agent2]).build() - custom_data = {"batch_id": "123"} - user_token = {"user_name": "bob"} + fi_kwargs = {"batch_id": "123"} async for event in workflow.run( "concurrent test", stream=True, - custom_data=custom_data, - user_token=user_token, + function_invocation_kwargs=fi_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -335,8 +178,7 @@ async def test_concurrent_kwargs_flow_to_agents() -> None: for agent in [agent1, agent2]: received = agent.captured_kwargs[0] - assert received.get("custom_data") == custom_data - assert received.get("user_token") == user_token + assert received.get("function_invocation_kwargs") == fi_kwargs # endregion @@ -346,7 +188,7 @@ async def test_concurrent_kwargs_flow_to_agents() -> None: async def test_groupchat_kwargs_flow_to_agents() -> None: - """Test that kwargs flow to agents in a group chat workflow.""" + """Test that function_invocation_kwargs flow to agents in a group chat workflow.""" agent1 = _KwargsCapturingAgent(name="chat1") agent2 = _KwargsCapturingAgent(name="chat2") @@ -368,9 +210,9 @@ def simple_selector(state: GroupChatState) -> str: selection_func=simple_selector, ).build() - custom_data = {"session_id": "group123"} + fi_kwargs = {"session_id": "group123"} - async for event in workflow.run("group chat test", custom_data=custom_data, stream=True): + async for event in workflow.run("group chat test", function_invocation_kwargs=fi_kwargs, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -379,7 +221,7 @@ def simple_selector(state: GroupChatState) -> str: assert len(all_kwargs) >= 1, "At least one agent should be invoked in group chat" for received in all_kwargs: - assert received.get("custom_data") == custom_data + assert received.get("function_invocation_kwargs") == fi_kwargs # endregion @@ -389,7 +231,7 @@ def simple_selector(state: GroupChatState) -> str: async def test_kwargs_stored_in_state() -> None: - """Test that kwargs are stored in State with the correct key.""" + """Test that function_invocation_kwargs are stored in State with the correct key.""" from agent_framework import Executor, WorkflowContext, handler stored_kwargs: dict[str, Any] | None = None @@ -404,13 +246,12 @@ async def inspect(self, msgs: list[Message], ctx: WorkflowContext[list[Message]] inspector = _StateInspector(id="inspector") workflow = SequentialBuilder(participants=[inspector]).build() - async for event in workflow.run("test", my_kwarg="my_value", another=123, stream=True): + async for event in workflow.run("test", function_invocation_kwargs={"my_kwarg": "my_value"}, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break assert stored_kwargs is not None, "kwargs should be stored in State" - assert stored_kwargs.get("my_kwarg") == "my_value" - assert stored_kwargs.get("another") == 123 + assert "function_invocation_kwargs" in stored_kwargs async def test_empty_kwargs_stored_as_empty_dict() -> None: @@ -444,24 +285,8 @@ async def check(self, msgs: list[Message], ctx: WorkflowContext[list[Message]]) # region Edge Cases -async def test_kwargs_with_none_values() -> None: - """Test that kwargs with None values are passed through correctly.""" - agent = _KwargsCapturingAgent(name="none_test") - workflow = SequentialBuilder(participants=[agent]).build() - - async for event in workflow.run("test", optional_param=None, other_param="value", stream=True): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - assert len(agent.captured_kwargs) >= 1 - received = agent.captured_kwargs[0] - assert "optional_param" in received - assert received["optional_param"] is None - assert received["other_param"] == "value" - - async def test_kwargs_with_complex_nested_data() -> None: - """Test that complex nested data structures flow through correctly.""" + """Test that complex nested data structures flow through correctly via function_invocation_kwargs.""" agent = _KwargsCapturingAgent(name="nested_test") workflow = SequentialBuilder(participants=[agent]).build() @@ -476,17 +301,17 @@ async def test_kwargs_with_complex_nested_data() -> None: "tuple_like": [1, 2, 3], } - async for event in workflow.run("test", complex_data=complex_data, stream=True): + async for event in workflow.run("test", function_invocation_kwargs=complex_data, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break assert len(agent.captured_kwargs) >= 1 received = agent.captured_kwargs[0] - assert received.get("complex_data") == complex_data + assert received.get("function_invocation_kwargs") == complex_data async def test_kwargs_preserved_on_response_continuation() -> None: - """Test that run kwargs are preserved when continuing a paused workflow with run(responses=...). + """Test that function_invocation_kwargs are preserved when continuing a paused workflow with run(responses=...). Regression test for #4293: kwargs were overwritten to {} on continuation calls. """ @@ -550,8 +375,9 @@ async def _done() -> AgentResponse: agent = _ApprovalCapturingAgent() workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build() - # Initial run with kwargs — workflow should pause for approval - result = await workflow.run("go", custom_data={"token": "abc"}) + # Initial run with function_invocation_kwargs — workflow should pause for approval + fi_kwargs = {"token": "abc"} + result = await workflow.run("go", function_invocation_kwargs=fi_kwargs) request_events = result.get_request_info_events() assert len(request_events) == 1 @@ -559,174 +385,14 @@ async def _done() -> AgentResponse: approval = request_events[0] await workflow.run(responses={approval.request_id: approval.data.to_function_approval_response(True)}) - # Both calls should have received the original kwargs + # Both calls should have received the original function_invocation_kwargs assert len(agent.captured_kwargs) == 2 - assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"} - assert agent.captured_kwargs[1].get("custom_data") == {"token": "abc"}, ( + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert agent.captured_kwargs[1].get("function_invocation_kwargs") == fi_kwargs, ( f"kwargs should be preserved on continuation, got: {agent.captured_kwargs[1]}" ) -async def test_kwargs_overridden_on_response_continuation() -> None: - """Test that explicitly provided kwargs override prior kwargs on continuation.""" - - class _ApprovalCapturingAgent(BaseAgent): - captured_kwargs: list[dict[str, Any]] - _asked: bool - - def __init__(self) -> None: - super().__init__(name="approval_agent", description="Test agent") - self.captured_kwargs = [] - self._asked = False - - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[False] = ..., - session: AgentSession | None = ..., - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]]: ... - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[True], - session: AgentSession | None = ..., - **kwargs: Any, - ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - - def run( - self, - messages: AgentRunInputs | None = None, - *, - stream: bool = False, - session: AgentSession | None = None, - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: - self.captured_kwargs.append(dict(kwargs)) - if not self._asked: - self._asked = True - - async def _pause() -> AgentResponse: - call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}") - req = Content.from_function_approval_request(id="r1", function_call=call) - return AgentResponse(messages=[Message("assistant", [req])]) - - return _pause() - - async def _done() -> AgentResponse: - return AgentResponse(messages=[Message("assistant", ["done"])]) - - return _done() - - from agent_framework import WorkflowBuilder - - agent = _ApprovalCapturingAgent() - workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build() - - result = await workflow.run("go", custom_data={"token": "abc"}) - request_events = result.get_request_info_events() - approval = request_events[0] - - # Continue with responses AND new kwargs — should override - await workflow.run( - responses={approval.request_id: approval.data.to_function_approval_response(True)}, - custom_data={"token": "xyz"}, - ) - - assert len(agent.captured_kwargs) == 2 - assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"} - assert agent.captured_kwargs[1].get("custom_data") == {"token": "xyz"} - - -async def test_kwargs_empty_value_passed_on_continuation() -> None: - """Test that explicitly passing a kwarg with an empty value on continuation overrides prior kwargs. - - This exercises the boundary where the caller provides kwargs (e.g., custom_data={}) - that differ from the original run. Because the kwargs dict is non-empty (it has a key), - it passes the `kwargs if kwargs else None` gate and the `is not None` check, so it - overwrites the previously stored kwargs. - """ - - class _ApprovalCapturingAgent(BaseAgent): - captured_kwargs: list[dict[str, Any]] - _asked: bool - - def __init__(self) -> None: - super().__init__(name="approval_agent", description="Test agent") - self.captured_kwargs = [] - self._asked = False - - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[False] = ..., - session: AgentSession | None = ..., - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]]: ... - @overload - def run( - self, - messages: AgentRunInputs | None = ..., - *, - stream: Literal[True], - session: AgentSession | None = ..., - **kwargs: Any, - ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - - def run( - self, - messages: AgentRunInputs | None = None, - *, - stream: bool = False, - session: AgentSession | None = None, - **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: - self.captured_kwargs.append(dict(kwargs)) - if not self._asked: - self._asked = True - - async def _pause() -> AgentResponse: - call = Content.from_function_call(call_id="c1", name="do_thing", arguments="{}") - req = Content.from_function_approval_request(id="r1", function_call=call) - return AgentResponse(messages=[Message("assistant", [req])]) - - return _pause() - - async def _done() -> AgentResponse: - return AgentResponse(messages=[Message("assistant", ["done"])]) - - return _done() - - from agent_framework import WorkflowBuilder - - agent = _ApprovalCapturingAgent() - workflow = WorkflowBuilder(start_executor=agent, output_executors=[agent]).build() - - # Initial run with non-empty kwargs - result = await workflow.run("go", custom_data={"token": "abc"}) - request_events = result.get_request_info_events() - assert len(request_events) == 1 - - # Continue with custom_data={} — explicitly clearing the value. - # kwargs={"custom_data": {}} is truthy (has a key), so run_kwargs is set. - approval = request_events[0] - await workflow.run( - responses={approval.request_id: approval.data.to_function_approval_response(True)}, - custom_data={}, - ) - - assert len(agent.captured_kwargs) == 2 - assert agent.captured_kwargs[0].get("custom_data") == {"token": "abc"} - # The continuation explicitly set custom_data={}, overriding the original - assert agent.captured_kwargs[1].get("custom_data") == {} - - async def test_kwargs_reset_context_stores_empty_dict() -> None: """Test that reset_context=True with no kwargs stores an empty dict. @@ -743,33 +409,9 @@ async def test_kwargs_reset_context_stores_empty_dict() -> None: break assert len(agent.captured_kwargs) >= 1 - # The only kwarg should be the framework-injected 'options' (no user-provided kwargs) received = agent.captured_kwargs[0] - assert "custom_data" not in received - assert received.get("options") is None - - -async def test_kwargs_preserved_across_workflow_reruns() -> None: - """Test that kwargs are correctly isolated between workflow runs.""" - agent = _KwargsCapturingAgent(name="rerun_test") - - # Build separate workflows for each run to avoid "already running" error - workflow1 = SequentialBuilder(participants=[agent]).build() - workflow2 = SequentialBuilder(participants=[agent]).build() - - # First run - async for event in workflow1.run("run1", run_id="first", stream=True): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - # Second run with different kwargs (using fresh workflow) - async for event in workflow2.run("run2", run_id="second", stream=True): - if event.type == "status" and event.state == WorkflowRunState.IDLE: - break - - assert len(agent.captured_kwargs) >= 2 - assert agent.captured_kwargs[0].get("run_id") == "first" - assert agent.captured_kwargs[1].get("run_id") == "second" + assert received.get("function_invocation_kwargs") is None + assert received.get("client_kwargs") is None # endregion @@ -780,7 +422,7 @@ async def test_kwargs_preserved_across_workflow_reruns() -> None: @pytest.mark.xfail(reason="Handoff workflow does not yet propagate kwargs to agents") async def test_handoff_kwargs_flow_to_agents() -> None: - """Test that kwargs flow to agents in a handoff workflow.""" + """Test that function_invocation_kwargs flow to agents in a handoff workflow.""" agent1 = _KwargsCapturingAgent(name="coordinator") agent2 = _KwargsCapturingAgent(name="specialist") @@ -792,15 +434,15 @@ async def test_handoff_kwargs_flow_to_agents() -> None: .build() ) - custom_data = {"session_id": "handoff123"} + fi_kwargs = {"session_id": "handoff123"} - async for event in workflow.run("handoff test", custom_data=custom_data, stream=True): + async for event in workflow.run("handoff test", function_invocation_kwargs=fi_kwargs, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break # Coordinator agent should have received kwargs assert len(agent1.captured_kwargs) >= 1, "Coordinator should be invoked in handoff" - assert agent1.captured_kwargs[0].get("custom_data") == custom_data + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs # endregion @@ -852,7 +494,7 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> Messa custom_data = {"session_id": "magentic123"} - async for event in workflow.run("magentic test", custom_data=custom_data, stream=True): + async for event in workflow.run("magentic test", function_invocation_kwargs=custom_data, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -903,7 +545,7 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> Messa # Use MagenticWorkflow.run() which goes through the kwargs attachment path custom_data = {"magentic_key": "magentic_value"} - async for event in magentic_workflow.run("test task", custom_data=custom_data, stream=True): + async for event in magentic_workflow.run("test task", function_invocation_kwargs=custom_data, stream=True): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -918,86 +560,61 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> Messa async def test_workflow_as_agent_run_propagates_kwargs_to_underlying_agent() -> None: - """Test that kwargs passed to workflow_agent.run() flow through to the underlying agents.""" + """Test that function_invocation_kwargs passed to workflow_agent.run() flow through to the underlying agents.""" agent = _KwargsCapturingAgent(name="inner_agent") workflow = SequentialBuilder(participants=[agent]).build() workflow_agent = workflow.as_agent(name="TestWorkflowAgent") - custom_data = {"endpoint": "https://api.example.com", "version": "v1"} - user_token = {"user_name": "alice", "access_level": "admin"} + fi_kwargs = {"endpoint": "https://api.example.com", "version": "v1"} _ = await workflow_agent.run( "test message", - custom_data=custom_data, - user_token=user_token, + function_invocation_kwargs=fi_kwargs, ) # Verify inner agent received kwargs assert len(agent.captured_kwargs) >= 1, "Inner agent should have been invoked at least once" received = agent.captured_kwargs[0] - assert "custom_data" in received, "Inner agent should receive custom_data kwarg" - assert "user_token" in received, "Inner agent should receive user_token kwarg" - assert received["custom_data"] == custom_data - assert received["user_token"] == user_token + assert received.get("function_invocation_kwargs") == fi_kwargs async def test_workflow_as_agent_run_stream_propagates_kwargs_to_underlying_agent() -> None: - """Test that kwargs passed to workflow_agent.run() flow through to the underlying agents.""" + """Test that function_invocation_kwargs passed to workflow_agent.run(stream=True) flow through.""" agent = _KwargsCapturingAgent(name="inner_agent") workflow = SequentialBuilder(participants=[agent]).build() workflow_agent = workflow.as_agent(name="TestWorkflowAgent") - custom_data = {"session_id": "xyz123"} - api_token = "secret-token" + fi_kwargs = {"session_id": "xyz123"} async for _ in workflow_agent.run( "test message", stream=True, - custom_data=custom_data, - api_token=api_token, + function_invocation_kwargs=fi_kwargs, ): pass # Verify inner agent received kwargs assert len(agent.captured_kwargs) >= 1, "Inner agent should have been invoked at least once" received = agent.captured_kwargs[0] - assert "custom_data" in received, "Inner agent should receive custom_data kwarg" - assert "api_token" in received, "Inner agent should receive api_token kwarg" - assert received["custom_data"] == custom_data - assert received["api_token"] == api_token + assert received.get("function_invocation_kwargs") == fi_kwargs async def test_workflow_as_agent_propagates_kwargs_to_multiple_agents() -> None: - """Test that kwargs flow to all agents when using workflow.as_agent().""" + """Test that function_invocation_kwargs flow to all agents when using workflow.as_agent().""" agent1 = _KwargsCapturingAgent(name="agent1") agent2 = _KwargsCapturingAgent(name="agent2") workflow = SequentialBuilder(participants=[agent1, agent2]).build() workflow_agent = workflow.as_agent(name="MultiAgentWorkflow") - custom_data = {"batch_id": "batch-001"} + fi_kwargs = {"batch_id": "batch-001"} - _ = await workflow_agent.run("test message", custom_data=custom_data) + _ = await workflow_agent.run("test message", function_invocation_kwargs=fi_kwargs) # Both agents should have received kwargs assert len(agent1.captured_kwargs) >= 1, "First agent should be invoked" assert len(agent2.captured_kwargs) >= 1, "Second agent should be invoked" - assert agent1.captured_kwargs[0].get("custom_data") == custom_data - assert agent2.captured_kwargs[0].get("custom_data") == custom_data - - -async def test_workflow_as_agent_kwargs_with_none_values() -> None: - """Test that kwargs with None values are passed through correctly via as_agent().""" - agent = _KwargsCapturingAgent(name="none_test_agent") - workflow = SequentialBuilder(participants=[agent]).build() - workflow_agent = workflow.as_agent(name="NoneTestWorkflow") - - _ = await workflow_agent.run("test", optional_param=None, other_param="value") - - assert len(agent.captured_kwargs) >= 1 - received = agent.captured_kwargs[0] - assert "optional_param" in received - assert received["optional_param"] is None - assert received["other_param"] == "value" + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs async def test_workflow_as_agent_kwargs_with_complex_nested_data() -> None: @@ -1016,11 +633,11 @@ async def test_workflow_as_agent_kwargs_with_complex_nested_data() -> None: }, } - _ = await workflow_agent.run("test", complex_data=complex_data) + _ = await workflow_agent.run("test", function_invocation_kwargs=complex_data) assert len(agent.captured_kwargs) >= 1 received = agent.captured_kwargs[0] - assert received.get("complex_data") == complex_data + assert received.get("function_invocation_kwargs") == complex_data # endregion @@ -1050,15 +667,13 @@ async def test_subworkflow_kwargs_propagation() -> None: outer_workflow = SequentialBuilder(participants=[subworkflow_executor]).build() # Define kwargs that should propagate to subworkflow - custom_data = {"api_key": "secret123", "endpoint": "https://api.example.com"} - user_token = {"user_name": "alice", "access_level": "admin"} + fi_kwargs = {"api_key": "secret123", "endpoint": "https://api.example.com"} # Run the outer workflow with kwargs async for event in outer_workflow.run( "test message for subworkflow", stream=True, - custom_data=custom_data, - user_token=user_token, + function_invocation_kwargs=fi_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -1069,17 +684,8 @@ async def test_subworkflow_kwargs_propagation() -> None: received_kwargs = inner_agent.captured_kwargs[0] # Verify kwargs were propagated from parent workflow to subworkflow agent - assert "custom_data" in received_kwargs, ( - f"Subworkflow agent should receive 'custom_data' kwarg. Received keys: {list(received_kwargs.keys())}" - ) - assert "user_token" in received_kwargs, ( - f"Subworkflow agent should receive 'user_token' kwarg. Received keys: {list(received_kwargs.keys())}" - ) - assert received_kwargs.get("custom_data") == custom_data, ( - f"Expected custom_data={custom_data}, got {received_kwargs.get('custom_data')}" - ) - assert received_kwargs.get("user_token") == user_token, ( - f"Expected user_token={user_token}, got {received_kwargs.get('user_token')}" + assert received_kwargs.get("function_invocation_kwargs") == fi_kwargs, ( + f"Expected function_invocation_kwargs={fi_kwargs}, got {received_kwargs.get('function_invocation_kwargs')}" ) @@ -1114,11 +720,11 @@ async def read_kwargs(self, msgs: list[Message], ctx: WorkflowContext[list[Messa outer_workflow = SequentialBuilder(participants=[subworkflow_executor]).build() # Run with kwargs + fi_kwargs = {"my_custom_kwarg": "should_be_propagated", "another_kwarg": 42} async for event in outer_workflow.run( "test", stream=True, - my_custom_kwarg="should_be_propagated", - another_kwarg=42, + function_invocation_kwargs=fi_kwargs, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -1128,11 +734,8 @@ async def read_kwargs(self, msgs: list[Message], ctx: WorkflowContext[list[Messa kwargs_in_subworkflow = captured_kwargs_from_state[0] - assert kwargs_in_subworkflow.get("my_custom_kwarg") == "should_be_propagated", ( - f"Expected 'my_custom_kwarg' in subworkflow got: {kwargs_in_subworkflow}" - ) - assert kwargs_in_subworkflow.get("another_kwarg") == 42, ( - f"Expected 'another_kwarg'=42 in subworkflow got: {kwargs_in_subworkflow}" + assert "function_invocation_kwargs" in kwargs_in_subworkflow, ( + f"Expected 'function_invocation_kwargs' in subworkflow state, got: {kwargs_in_subworkflow}" ) @@ -1164,7 +767,7 @@ async def test_nested_subworkflow_kwargs_propagation() -> None: async for event in outer_workflow.run( "deeply nested test", stream=True, - deep_kwarg="should_reach_inner", + function_invocation_kwargs={"deep_kwarg": "should_reach_inner"}, ): if event.type == "status" and event.state == WorkflowRunState.IDLE: break @@ -1173,9 +776,250 @@ async def test_nested_subworkflow_kwargs_propagation() -> None: assert len(inner_agent.captured_kwargs) >= 1, "Deeply nested agent should be invoked" received = inner_agent.captured_kwargs[0] - assert received.get("deep_kwarg") == "should_reach_inner", ( + assert received.get("function_invocation_kwargs") == {"deep_kwarg": "should_reach_inner"}, ( f"Deeply nested agent should receive 'deep_kwarg'. Got: {received}" ) # endregion + + +# region Per-Executor Invocation Kwargs Tests + + +async def test_function_and_client_kwargs_together() -> None: + """Both function_invocation_kwargs and client_kwargs can be provided in the same call.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + fi_kwargs = {"tool_param": "tool_value"} + ci_kwargs = {"temperature": 0.7} + + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs=fi_kwargs, + client_kwargs=ci_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + # Both agents should receive both kwargs + for agent in [agent1, agent2]: + assert len(agent.captured_kwargs) >= 1 + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert agent.captured_kwargs[0].get("client_kwargs") == ci_kwargs + + +async def test_global_function_invocation_kwargs_flow_to_all_agents() -> None: + """Global function_invocation_kwargs should be received by all agents in a sequential workflow.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + fi_kwargs = {"tool_param": "shared_value"} + + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs=fi_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + # Both agents should receive function_invocation_kwargs + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") == fi_kwargs + + +async def test_per_executor_function_invocation_kwargs_routes_to_correct_agent() -> None: + """Per-executor function_invocation_kwargs should only be received by the targeted agent.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + # Per-executor: keys match agent names (which are used as executor IDs) + fi_kwargs = { + "agent1": {"tool_param": "value_for_agent1"}, + "agent2": {"tool_param": "value_for_agent2"}, + } + + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs=fi_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + # Each agent should receive only its own kwargs + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == {"tool_param": "value_for_agent1"} + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") == {"tool_param": "value_for_agent2"} + + +async def test_per_executor_kwargs_unmatched_agent_gets_none() -> None: + """An agent not targeted in per-executor kwargs should receive None for that kwarg.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + # Only agent1 is targeted + fi_kwargs = {"agent1": {"tool_param": "only_for_agent1"}} + + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs=fi_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("function_invocation_kwargs") == {"tool_param": "only_for_agent1"} + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("function_invocation_kwargs") is None + + +async def test_global_client_kwargs_flow_to_all_agents() -> None: + """Global client_kwargs should be received by all agents.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + ci_kwargs = {"temperature": 0.5} + + async for event in workflow.run( + "test", + stream=True, + client_kwargs=ci_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("client_kwargs") == ci_kwargs + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("client_kwargs") == ci_kwargs + + +async def test_per_executor_client_kwargs_routes_correctly() -> None: + """Per-executor client_kwargs should only be received by the targeted agent.""" + agent1 = _KwargsCapturingAgent(name="agent1") + agent2 = _KwargsCapturingAgent(name="agent2") + workflow = SequentialBuilder(participants=[agent1, agent2]).build() + + ci_kwargs = { + "agent1": {"temperature": 0.1}, + "agent2": {"temperature": 0.9}, + } + + async for event in workflow.run( + "test", + stream=True, + client_kwargs=ci_kwargs, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + assert len(agent1.captured_kwargs) >= 1 + assert agent1.captured_kwargs[0].get("client_kwargs") == {"temperature": 0.1} + assert len(agent2.captured_kwargs) >= 1 + assert agent2.captured_kwargs[0].get("client_kwargs") == {"temperature": 0.9} + + +async def test_resolve_invocation_kwargs_logs_per_executor(caplog: "LogCaptureFixture") -> None: + """Workflow._resolve_invocation_kwargs logs info when per-executor format is detected.""" + import logging + + agent = _KwargsCapturingAgent(name="agent1") + workflow = SequentialBuilder(participants=[agent]).build() + + with caplog.at_level(logging.INFO): + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs={"agent1": {"key": "val"}}, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + per_executor_logs = [r for r in caplog.records if "per-executor" in r.message.lower()] + assert len(per_executor_logs) >= 1 + + +async def test_resolve_invocation_kwargs_logs_global(caplog: "LogCaptureFixture") -> None: + """Workflow._resolve_invocation_kwargs logs info when global format is detected.""" + import logging + + agent = _KwargsCapturingAgent(name="agent1") + workflow = SequentialBuilder(participants=[agent]).build() + + with caplog.at_level(logging.INFO): + async for event in workflow.run( + "test", + stream=True, + function_invocation_kwargs={"tool_key": "tool_val"}, + ): + if event.type == "status" and event.state == WorkflowRunState.IDLE: + break + + global_logs = [r for r in caplog.records if "global kwargs" in r.message.lower()] + assert len(global_logs) >= 1 + + +async def test_empty_function_invocation_kwargs_clears_previous() -> None: + """Passing function_invocation_kwargs={} should clear previously stored kwargs on a new run.""" + agent = _KwargsCapturingAgent(name="clearing_agent") + workflow = SequentialBuilder(participants=[agent]).build() + + # First run: provide kwargs + await workflow.run( + "first", + function_invocation_kwargs={"key": "value"}, + ) + + assert len(agent.captured_kwargs) >= 1 + assert agent.captured_kwargs[0].get("function_invocation_kwargs") == {"key": "value"} + + # Second run: pass empty dict to explicitly clear + await workflow.run( + "second", + function_invocation_kwargs={}, + ) + + # Agent should receive None because the empty dict resolves to an empty + # __global__ entry which is treated as "no kwargs" for each executor. + assert len(agent.captured_kwargs) >= 2 + assert agent.captured_kwargs[-1].get("function_invocation_kwargs") == {} + + +async def test_empty_client_kwargs_clears_previous() -> None: + """Passing client_kwargs={} should clear previously stored kwargs on a new run.""" + agent = _KwargsCapturingAgent(name="clearing_agent") + workflow = SequentialBuilder(participants=[agent]).build() + + # First run: provide kwargs + await workflow.run( + "first", + client_kwargs={"temperature": 0.5}, + ) + + assert len(agent.captured_kwargs) >= 1 + assert agent.captured_kwargs[0].get("client_kwargs") == {"temperature": 0.5} + + # Second run: pass empty dict to explicitly clear + await workflow.run( + "second", + client_kwargs={}, + ) + + assert len(agent.captured_kwargs) >= 2 + assert agent.captured_kwargs[-1].get("client_kwargs") == {} + + +# endregion diff --git a/python/samples/03-workflows/state-management/workflow_kwargs.py b/python/samples/03-workflows/state-management/workflow_kwargs.py deleted file mode 100644 index 0d50b8710d..0000000000 --- a/python/samples/03-workflows/state-management/workflow_kwargs.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -import json -import os -from typing import Annotated, Any, cast - -from agent_framework import Agent, Message, tool -from agent_framework.foundry import FoundryChatClient -from agent_framework.orchestrations import SequentialBuilder -from azure.identity import AzureCliCredential -from dotenv import load_dotenv -from pydantic import Field - -# Load environment variables from .env file -load_dotenv() - -""" -Sample: Workflow kwargs Flow to @tool Tools - -This sample demonstrates how to flow custom context (skill data, user tokens, etc.) -through any workflow pattern to @tool functions using the **kwargs pattern. - -Key Concepts: -- Pass custom context as kwargs when invoking workflow.run() -- kwargs are stored in State and passed to all agent invocations -- @tool functions receive kwargs via **kwargs parameter -- Works with Sequential, Concurrent, GroupChat, Handoff, and Magentic patterns - -Prerequisites: -- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint. -- FOUNDRY_MODEL must be set to your Azure OpenAI model deployment name. -""" - - -# Define tools that accept custom context via **kwargs -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; -# see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. -@tool(approval_mode="never_require") -def get_user_data( - query: Annotated[str, Field(description="What user data to retrieve")], - **kwargs: Any, -) -> str: - """Retrieve user-specific data based on the authenticated context.""" - user_token = kwargs.get("user_token", {}) - user_name = user_token.get("user_name", "anonymous") - access_level = user_token.get("access_level", "none") - - print(f"\n[get_user_data] Received kwargs keys: {list(kwargs.keys())}") - print(f"[get_user_data] User: {user_name}") - print(f"[get_user_data] Access level: {access_level}") - - return f"Retrieved data for user {user_name} with {access_level} access: {query}" - - -@tool(approval_mode="never_require") -def call_api( - endpoint_name: Annotated[str, Field(description="Name of the API endpoint to call")], - **kwargs: Any, -) -> str: - """Call an API using the configured endpoints from custom_data.""" - custom_data = kwargs.get("custom_data", {}) - api_config = custom_data.get("api_config", {}) - - base_url = api_config.get("base_url", "unknown") - endpoints = api_config.get("endpoints", {}) - - print(f"\n[call_api] Received kwargs keys: {list(kwargs.keys())}") - print(f"[call_api] Base URL: {base_url}") - print(f"[call_api] Available endpoints: {list(endpoints.keys())}") - - if endpoint_name in endpoints: - return f"Called {base_url}{endpoints[endpoint_name]} successfully" - return f"Endpoint '{endpoint_name}' not found in configuration" - - -async def main() -> None: - print("=" * 70) - print("Workflow kwargs Flow Demo (SequentialBuilder)") - print("=" * 70) - - # Create chat client - client = FoundryChatClient( - project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], - model=os.environ["FOUNDRY_MODEL"], - credential=AzureCliCredential(), - ) - - # Create agent with tools that use kwargs - agent = Agent( - client=client, - name="assistant", - instructions=( - "You are a helpful assistant. Use the available tools to help users. " - "When asked about user data, use get_user_data. " - "When asked to call an API, use call_api." - ), - tools=[get_user_data, call_api], - ) - - # Build a simple sequential workflow - workflow = SequentialBuilder(participants=[agent]).build() - - # Define custom context that will flow to tools via kwargs - custom_data = { - "api_config": { - "base_url": "https://api.example.com", - "endpoints": { - "users": "/v1/users", - "orders": "/v1/orders", - "products": "/v1/products", - }, - }, - } - - user_token = { - "user_name": "bob@contoso.com", - "access_level": "admin", - } - - print("\nCustom Data being passed:") - print(json.dumps(custom_data, indent=2)) - print(f"\nUser: {user_token['user_name']}") - print("\n" + "-" * 70) - print("Workflow Execution (watch for [tool_name] logs showing kwargs received):") - print("-" * 70) - - # Run workflow with kwargs - these will flow through to tools - async for event in workflow.run( - "Please get my user data and then call the users API endpoint.", - additional_function_arguments={"custom_data": custom_data, "user_token": user_token}, - stream=True, - ): - if event.type == "output": - output_data = cast(list[Message], event.data) - if isinstance(output_data, list): - for item in output_data: - if isinstance(item, Message) and item.text: - print(f"\n[Final Answer]: {item.text}") - - print("\n" + "=" * 70) - print("Sample Complete") - print("=" * 70) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/03-workflows/state-management/workflow_kwargs_global.py b/python/samples/03-workflows/state-management/workflow_kwargs_global.py new file mode 100644 index 0000000000..89f0c1672e --- /dev/null +++ b/python/samples/03-workflows/state-management/workflow_kwargs_global.py @@ -0,0 +1,170 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import json +import os +from typing import Annotated, Any, cast + +from agent_framework import Agent, Message, tool +from agent_framework.foundry import FoundryChatClient +from agent_framework.orchestrations import SequentialBuilder +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from pydantic import Field + +# Load environment variables from .env file +load_dotenv() + +""" +Sample: Global Workflow kwargs + +This sample demonstrates how to pass the same kwargs to every agent in a +workflow using global targeting. When keys in function_invocation_kwargs do NOT +match any executor ID (agent name), the framework treats them as global and +delivers them to all agents. + +Compare with workflow_kwargs_per_agent.py which targets kwargs to specific agents. + +Key Concepts: +- Global function_invocation_kwargs are delivered to every agent in the workflow +- Useful when all agents share the same credentials, config, or context +- @tool functions receive kwargs via the **kwargs parameter + +Prerequisites: +- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint. +- Environment variables configured +""" + + +# 1. Define a tool for the research agent — queries a company's internal +# database using credentials passed via global kwargs. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; +# see samples/02-agents/tools/function_tool_with_approval.py +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. +@tool(approval_mode="never_require") +def query_company_database( + query: Annotated[ + str, Field(description="The database query to run, e.g. 'Q3 revenue' or 'headcount by department'") + ], + **kwargs: Any, +) -> str: + """Query the company's internal database for business metrics and data.""" + db_config = kwargs.get("db_config", {}) + connection_string = db_config.get("connection_string", "") + database = db_config.get("database", "") + + if not connection_string or not database: + return f"ERROR: missing db_config — cannot run query '{query}'" + + print(f"\n [query_company_database] Connecting to {database} at {connection_string[:30]}...") + + # Simulated company data that the LLM would not know on its own + return ( + f"Query results from {database}:\n" + f"- Contoso Q3 2025 revenue: $47.2M (up 12% YoY)\n" + f"- Top product line: CloudSync Pro ($18.6M)\n" + f"- Engineering headcount: 342 (up from 298 in Q2)\n" + f"- Customer churn rate: 4.1% (down from 5.3% in Q2)\n" + f"- Net new enterprise customers: 28" + ) + + +# 2. Define a tool for the writer agent — retrieves the formatting style +# from user preferences passed via global kwargs. +@tool(approval_mode="never_require") +def get_formatting_instructions( + section_title: Annotated[str, Field(description="The title of the section or report to format")], + **kwargs: Any, +) -> str: + """Get the formatting instructions based on user preferences.""" + user_prefs = kwargs.get("user_preferences", {}) + output_format = user_prefs.get("format", "plain") + language = user_prefs.get("language", "en") + + print(f"\n [get_formatting_instructions] Format: {output_format}, Language: {language}") + + return ( + f"Formatting rules for '{section_title}':\n" + f"- Output format: {output_format}\n" + f"- Language/locale: {language}\n" + f"- Include a footer: 'Generated in {output_format} for locale {language}'" + ) + + +async def main() -> None: + print("=" * 70) + print("Global Workflow kwargs Demo") + print("=" * 70) + + # 3. Create a shared chat client. + client = FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AzureCliCredential(), + ) + + # 4. Create two agents with different tools and responsibilities. + researcher = Agent( + client=client, + name="researcher", + instructions=( + "You are a data analyst. Call query_company_database exactly once " + "with the user's request as the query. Return the raw results." + ), + tools=[query_company_database], + ) + + writer = Agent( + client=client, + name="writer", + instructions=( + "You are a report writer. Call get_formatting_instructions exactly once, " + "then rewrite the data you receive into a polished report following those rules." + ), + tools=[get_formatting_instructions], + ) + + # 5. Build a sequential workflow: researcher -> writer. + workflow = SequentialBuilder(participants=[researcher, writer]).build() + + # 6. Define global kwargs — every agent receives all of these. + # Because the keys ("db_config", "user_preferences") do NOT match any + # executor ID ("researcher", "writer"), the framework treats them as + # global and delivers the full dict to every agent. + global_fi_kwargs = { + "db_config": { + "connection_string": "Server=contoso-sql.database.windows.net;Database=metrics", + "database": "contoso_metrics_prod", + }, + "user_preferences": { + "format": "markdown", + "language": "en-US", + }, + } + + print("\nGlobal function_invocation_kwargs (sent to all agents):") + print(json.dumps(global_fi_kwargs, indent=2)) + print("\n" + "-" * 70) + print("Workflow Execution:") + print("-" * 70) + + # 7. Run the workflow — every agent receives the same global kwargs. + async for event in workflow.run( + "Pull Contoso's Q3 2025 performance data and write an executive summary.", + function_invocation_kwargs=global_fi_kwargs, + stream=True, + ): + if event.type == "output": + output_data = cast(list[Message], event.data) + if isinstance(output_data, list): + for item in output_data: + if isinstance(item, Message) and item.text: + print(f"\n[{item.author_name}]: {item.text}") + + print("\n" + "=" * 70) + print("Sample Complete") + print("=" * 70) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py b/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py new file mode 100644 index 0000000000..d35d409301 --- /dev/null +++ b/python/samples/03-workflows/state-management/workflow_kwargs_per_agent.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import json +import os +from typing import Annotated, Any, cast + +from agent_framework import Agent, Message, tool +from agent_framework.foundry import FoundryChatClient +from agent_framework.orchestrations import SequentialBuilder +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from pydantic import Field + +# Load environment variables from .env file +load_dotenv() + +""" +Sample: Per-Agent Workflow kwargs + +This sample demonstrates how to pass different kwargs to different agents in a +workflow using per-agent targeting. When keys in function_invocation_kwargs (or +client_kwargs) match executor IDs (agent names by default), each agent +receives only its own slice of the kwargs. + +Key Concepts: +- Per-agent function_invocation_kwargs target specific agents by executor ID +- Agents only receive the kwargs assigned to them (not other agents' kwargs) +- Useful when different agents need different credentials, configs, or context + +Prerequisites: +- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint. +- Environment variables configured +""" + + +# 1. Define a tool for the research agent — queries a company's internal +# database using credentials passed via per-agent kwargs. +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; +# see samples/02-agents/tools/function_tool_with_approval.py +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. +@tool(approval_mode="never_require") +def query_company_database( + query: Annotated[ + str, Field(description="The database query to run, e.g. 'Q3 revenue' or 'headcount by department'") + ], + **kwargs: Any, +) -> str: + """Query the company's internal database for business metrics and data.""" + db_config = kwargs.get("db_config", {}) + connection_string = db_config.get("connection_string", "") + database = db_config.get("database", "") + + if not connection_string or not database: + return f"ERROR: missing db_config — cannot run query '{query}'" + + print(f"\n [query_company_database] Connecting to {database} at {connection_string[:30]}...") + + # Simulated company data that the LLM would not know on its own + return ( + f"Query results from {database}:\n" + f"- Contoso Q3 2025 revenue: $47.2M (up 12% YoY)\n" + f"- Top product line: CloudSync Pro ($18.6M)\n" + f"- Engineering headcount: 342 (up from 298 in Q2)\n" + f"- Customer churn rate: 4.1% (down from 5.3% in Q2)\n" + f"- Net new enterprise customers: 28" + ) + + +# 2. Define a tool for the writer agent — retrieves the formatting style +# from user preferences passed via per-agent kwargs. +@tool(approval_mode="never_require") +def get_formatting_instructions( + section_title: Annotated[str, Field(description="The title of the section or report to format")], + **kwargs: Any, +) -> str: + """Get the formatting instructions based on user preferences.""" + user_prefs = kwargs.get("user_preferences", {}) + output_format = user_prefs.get("format", "plain") + language = user_prefs.get("language", "en") + + print(f"\n [get_formatting_instructions] Format: {output_format}, Language: {language}") + + return ( + f"Formatting rules for '{section_title}':\n" + f"- Output format: {output_format}\n" + f"- Language/locale: {language}\n" + f"- Include a footer: 'Generated in {output_format} for locale {language}'" + ) + + +async def main() -> None: + print("=" * 70) + print("Per-Agent Workflow kwargs Demo") + print("=" * 70) + + # 3. Create a shared chat client. + client = FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AzureCliCredential(), + ) + + # 4. Create two agents with different tools and responsibilities. + researcher = Agent( + client=client, + name="researcher", + instructions=( + "You are a data analyst. Call query_company_database exactly once " + "with the user's request as the query. Return the raw results." + ), + tools=[query_company_database], + ) + + writer = Agent( + client=client, + name="writer", + instructions=( + "You are a report writer. Call get_formatting_instructions exactly once, " + "then rewrite the data you receive into a polished report following those rules." + ), + tools=[get_formatting_instructions], + ) + + # 5. Build a sequential workflow: researcher -> writer. + workflow = SequentialBuilder(participants=[researcher, writer]).build() + + # 6. Define per-agent kwargs — each agent gets only its own config. + # The keys ("researcher", "writer") match the agent names, which are + # used as executor IDs by default. + per_agent_fi_kwargs = { + "researcher": { + "db_config": { + "connection_string": "Server=contoso-sql.database.windows.net;Database=metrics", + "database": "contoso_metrics_prod", + }, + }, + "writer": { + "user_preferences": { + "format": "markdown", + "language": "en-US", + }, + }, + } + + print("\nPer-agent function_invocation_kwargs:") + print(json.dumps(per_agent_fi_kwargs, indent=2)) + print("\n" + "-" * 70) + print("Workflow Execution:") + print("-" * 70) + + # 7. Run the workflow — each agent receives only its targeted kwargs. + async for event in workflow.run( + "Pull Contoso's Q3 2025 performance data and write an executive summary.", + function_invocation_kwargs=per_agent_fi_kwargs, + stream=True, + ): + if event.type == "output": + output_data = cast(list[Message], event.data) + if isinstance(output_data, list): + for item in output_data: + if isinstance(item, Message) and item.text: + print(f"\n[{item.author_name}]: {item.text}") + + print("\n" + "=" * 70) + print("Sample Complete") + print("=" * 70) + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: + +Per-agent function_invocation_kwargs: +{ + "researcher": { + "db_config": { + "connection_string": "Server=contoso-sql.database.windows.net;Database=metrics", + "database": "contoso_metrics_prod" + } + }, + "writer": { + "user_preferences": { + "format": "markdown", + "language": "en-US" + } + } +} + +---------------------------------------------------------------------- +Workflow Execution: +---------------------------------------------------------------------- + + [query_company_database] Connecting to contoso_metrics_prod at Server=contoso-sql.database.wi... + +[researcher]: Here is Contoso's Q3 2025 data: +- Revenue: $47.2M (up 12% YoY) +- Top product: CloudSync Pro ($18.6M) +- Engineering headcount: 342 +- Churn rate: 4.1% +- Net new enterprise customers: 28 + + [get_formatting_instructions] Format: markdown, Language: en-US + +[writer]: # Contoso Q3 2025 Executive Summary + +| Metric | Value | +|---|---| +| Revenue | $47.2M (+12% YoY) | +| Top Product | CloudSync Pro ($18.6M) | +| Engineering Headcount | 342 | +| Customer Churn | 4.1% | +| New Enterprise Customers | 28 | + +Generated in markdown for locale en-US + +====================================================================== +Sample Complete +====================================================================== +"""