diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index d2fb59bbb6..7a1b974a38 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -111,8 +111,8 @@ def _map_update(update: ChatResponseUpdate) -> ChatResponseUpdate: @_apply_server_function_call_unwrap class AGUIChatClient( - ChatMiddlewareLayer[AGUIChatOptionsT], FunctionInvocationLayer[AGUIChatOptionsT], + ChatMiddlewareLayer[AGUIChatOptionsT], ChatTelemetryLayer[AGUIChatOptionsT], BaseChatClient[AGUIChatOptionsT], Generic[AGUIChatOptionsT], diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index 42a6967371..744196dbdf 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -45,8 +45,8 @@ def pytest_configure() -> None: class StreamingChatClientStub( - ChatMiddlewareLayer[OptionsCoT], FunctionInvocationLayer[OptionsCoT], + ChatMiddlewareLayer[OptionsCoT], ChatTelemetryLayer[OptionsCoT], BaseChatClient[OptionsCoT], Generic[OptionsCoT], @@ -54,7 +54,7 @@ class StreamingChatClientStub( """Typed streaming stub that satisfies SupportsChatGetResponse.""" def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None: - super().__init__(function_middleware=[]) + super().__init__(middleware=[]) self._stream_fn = stream_fn self._response_fn = response_fn self.last_session: AgentSession | None = None diff --git a/python/packages/anthropic/agent_framework_anthropic/__init__.py b/python/packages/anthropic/agent_framework_anthropic/__init__.py index 706740a127..ad0cff9648 100644 --- a/python/packages/anthropic/agent_framework_anthropic/__init__.py +++ b/python/packages/anthropic/agent_framework_anthropic/__init__.py @@ -2,7 +2,7 @@ import importlib.metadata -from ._chat_client import AnthropicChatOptions, AnthropicClient +from ._chat_client import AnthropicChatOptions, AnthropicClient, RawAnthropicClient try: __version__ = importlib.metadata.version(__name__) @@ -12,5 +12,6 @@ __all__ = [ "AnthropicChatOptions", "AnthropicClient", + "RawAnthropicClient", "__version__", ] diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index a1915a69fb..b3b61a4640 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -68,6 +68,7 @@ __all__ = [ "AnthropicChatOptions", "AnthropicClient", + "RawAnthropicClient", "ThinkingConfig", ] @@ -210,14 +211,24 @@ class AnthropicSettings(TypedDict, total=False): chat_model_id: str | None -class AnthropicClient( - ChatMiddlewareLayer[AnthropicOptionsT], - FunctionInvocationLayer[AnthropicOptionsT], - ChatTelemetryLayer[AnthropicOptionsT], +class RawAnthropicClient( BaseChatClient[AnthropicOptionsT], Generic[AnthropicOptionsT], ): - """Anthropic Chat client with middleware, telemetry, and function invocation support.""" + """Raw Anthropic chat client without middleware, telemetry, or function invocation support. + + Warning: + **This class should not normally be used directly.** It does not include middleware, + telemetry, or function invocation support that you most likely need. If you do use it, + you should consider which additional layers to apply. There is a defined ordering that + you should follow: + + 1. 
**FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware + 2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry + 3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry + + Use ``AnthropicClient`` instead for a fully-featured client with all layers applied. + """ OTEL_PROVIDER_NAME: ClassVar[str] = "anthropic" # type: ignore[reportIncompatibleVariableOverride, misc] @@ -229,12 +240,10 @@ def __init__( anthropic_client: AsyncAnthropic | None = None, additional_beta_flags: list[str] | None = None, additional_properties: dict[str, Any] | None = None, - middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, - function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, ) -> None: - """Initialize an Anthropic Agent client. + """Initialize a raw Anthropic client. Keyword Args: api_key: The Anthropic API key to use for authentication. @@ -245,15 +254,13 @@ def __init__( additional_beta_flags: Additional beta flags to enable on the client. Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25". additional_properties: Additional properties stored on the client instance. - middleware: Optional middleware to apply to the client. - function_invocation_configuration: Optional function invocation configuration override. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. Examples: .. code-block:: python - from agent_framework.anthropic import AnthropicClient + from agent_framework.anthropic import RawAnthropicClient from azure.identity.aio import DefaultAzureCredential # Using environment variables @@ -261,13 +268,13 @@ def __init__( # ANTHROPIC_CHAT_MODEL_ID=claude-sonnet-4-5-20250929 # Or passing parameters directly - client = AnthropicClient( + client = RawAnthropicClient( model_id="claude-sonnet-4-5-20250929", api_key="your_anthropic_api_key", ) # Or loading from a .env file - client = AnthropicClient(env_file_path="path/to/.env") + client = RawAnthropicClient(env_file_path="path/to/.env") # Or passing in an existing client from anthropic import AsyncAnthropic @@ -275,7 +282,7 @@ def __init__( anthropic_client = AsyncAnthropic( api_key="your_anthropic_api_key", base_url="https://custom-anthropic-endpoint.com" ) - client = AnthropicClient( + client = RawAnthropicClient( model_id="claude-sonnet-4-5-20250929", anthropic_client=anthropic_client, ) @@ -289,7 +296,7 @@ class MyOptions(AnthropicChatOptions, total=False): my_custom_option: str - client: AnthropicClient[MyOptions] = AnthropicClient(model_id="claude-sonnet-4-5-20250929") + client: RawAnthropicClient[MyOptions] = RawAnthropicClient(model_id="claude-sonnet-4-5-20250929") response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ @@ -320,8 +327,6 @@ class MyOptions(AnthropicChatOptions, total=False): # Initialize parent super().__init__( additional_properties=additional_properties, - middleware=middleware, - function_invocation_configuration=function_invocation_configuration, ) # Initialize instance variables @@ -1376,3 +1381,95 @@ def service_url(self) -> str: The service URL for the chat client, or None if not set. 
""" return str(self.anthropic_client.base_url) + + +class AnthropicClient( + FunctionInvocationLayer[AnthropicOptionsT], + ChatMiddlewareLayer[AnthropicOptionsT], + ChatTelemetryLayer[AnthropicOptionsT], + RawAnthropicClient[AnthropicOptionsT], + Generic[AnthropicOptionsT], +): + """Anthropic chat client with middleware, telemetry, and function invocation support.""" + + def __init__( + self, + *, + api_key: str | None = None, + model_id: str | None = None, + anthropic_client: AsyncAnthropic | None = None, + additional_beta_flags: list[str] | None = None, + additional_properties: dict[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + function_invocation_configuration: FunctionInvocationConfiguration | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize an Anthropic client. + + Keyword Args: + api_key: The Anthropic API key to use for authentication. + model_id: The ID of the model to use. + anthropic_client: An existing Anthropic client to use. If not provided, one will be created. + This can be used to further configure the client before passing it in. + For instance if you need to set a different base_url for testing or private deployments. + additional_beta_flags: Additional beta flags to enable on the client. + Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25". + additional_properties: Additional properties stored on the client instance. + middleware: Optional middleware to apply to the client. + function_invocation_configuration: Optional function invocation configuration override. + env_file_path: Path to environment file for loading settings. + env_file_encoding: Encoding of the environment file. + + Examples: + .. 
code-block:: python + + from agent_framework.anthropic import AnthropicClient + + # Using environment variables + # Set ANTHROPIC_API_KEY=your_anthropic_api_key + # ANTHROPIC_CHAT_MODEL_ID=claude-sonnet-4-5-20250929 + + # Or passing parameters directly + client = AnthropicClient( + model_id="claude-sonnet-4-5-20250929", + api_key="your_anthropic_api_key", + ) + + # Or loading from a .env file + client = AnthropicClient(env_file_path="path/to/.env") + + # Or passing in an existing client + from anthropic import AsyncAnthropic + + anthropic_client = AsyncAnthropic( + api_key="your_anthropic_api_key", base_url="https://custom-anthropic-endpoint.com" + ) + client = AnthropicClient( + model_id="claude-sonnet-4-5-20250929", + anthropic_client=anthropic_client, + ) + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework.anthropic import AnthropicChatOptions + + + class MyOptions(AnthropicChatOptions, total=False): + my_custom_option: str + + + client: AnthropicClient[MyOptions] = AnthropicClient(model_id="claude-sonnet-4-5-20250929") + response = await client.get_response("Hello", options={"my_custom_option": "value"}) + """ + super().__init__( + api_key=api_key, + model_id=model_id, + anthropic_client=anthropic_client, + additional_beta_flags=additional_beta_flags, + additional_properties=additional_properties, + middleware=middleware, + function_invocation_configuration=function_invocation_configuration, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) diff --git a/python/packages/anthropic/tests/test_anthropic_client.py b/python/packages/anthropic/tests/test_anthropic_client.py index 272239b1d7..258cc275ca 100644 --- a/python/packages/anthropic/tests/test_anthropic_client.py +++ b/python/packages/anthropic/tests/test_anthropic_client.py @@ -6,15 +6,18 @@ import pytest from agent_framework import ( + ChatMiddlewareLayer, ChatOptions, ChatResponseUpdate, Content, + FunctionInvocationLayer, Message, SupportsChatGetResponse, tool, ) from agent_framework._settings import load_settings from agent_framework._tools import SHELL_TOOL_KIND_VALUE +from agent_framework.observability import ChatTelemetryLayer from anthropic.types.beta import ( BetaMessage, BetaTextBlock, @@ -23,7 +26,7 @@ ) from pydantic import BaseModel, Field -from agent_framework_anthropic import AnthropicClient +from agent_framework_anthropic import AnthropicClient, RawAnthropicClient from agent_framework_anthropic._chat_client import AnthropicSettings # Test constants @@ -64,6 +67,8 @@ def create_test_anthropic_client( client.additional_beta_flags = [] client.chat_middleware = [] client.function_middleware = [] + client._cached_chat_middleware_pipeline = None + client._cached_function_middleware_pipeline = None client.function_invocation_configuration = normalize_function_invocation_configuration(None) return client @@ -117,6 +122,19 @@ def test_anthropic_client_init_with_client(mock_anthropic_client: MagicMock) -> assert isinstance(client, SupportsChatGetResponse) +def test_anthropic_client_wraps_raw_client_with_standard_layer_order() -> None: + """Test AnthropicClient composes the standard public layer stack around the raw client.""" + assert issubclass(AnthropicClient, RawAnthropicClient) + mro = AnthropicClient.__mro__ + assert mro.index(FunctionInvocationLayer) < mro.index(ChatMiddlewareLayer) + assert mro.index(ChatMiddlewareLayer) < mro.index(ChatTelemetryLayer) + assert mro.index(ChatTelemetryLayer) < mro.index(RawAnthropicClient) + # RawAnthropicClient 
must not include the convenience layers + assert not issubclass(RawAnthropicClient, FunctionInvocationLayer) + assert not issubclass(RawAnthropicClient, ChatMiddlewareLayer) + assert not issubclass(RawAnthropicClient, ChatTelemetryLayer) + + def test_anthropic_client_init_auto_create_client( anthropic_unit_test_env: dict[str, str], ) -> None: diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index d349ef3247..63db1663d8 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -206,8 +206,8 @@ class AzureAIAgentOptions(ChatOptions, total=False): class AzureAIAgentClient( - ChatMiddlewareLayer[AzureAIAgentOptionsT], FunctionInvocationLayer[AzureAIAgentOptionsT], + ChatMiddlewareLayer[AzureAIAgentOptionsT], ChatTelemetryLayer[AzureAIAgentOptionsT], BaseChatClient[AzureAIAgentOptionsT], Generic[AzureAIAgentOptionsT], diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 1fc6c7c1c9..34ac6f29a5 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -97,9 +97,9 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[ you should consider which additional layers to apply. There is a defined ordering that you should follow: - 1. **ChatMiddlewareLayer** - Should be applied first as it also prepares function middleware - 2. **FunctionInvocationLayer** - Handles tool/function calling loop - 3. **ChatTelemetryLayer** - Must be inside the function calling loop for correct per-call telemetry + 1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware + 2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry + 3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry Use ``AzureAIClient`` instead for a fully-featured client with all layers applied. 
""" @@ -1214,8 +1214,8 @@ def as_agent( class AzureAIClient( - ChatMiddlewareLayer[AzureAIClientOptionsT], FunctionInvocationLayer[AzureAIClientOptionsT], + ChatMiddlewareLayer[AzureAIClientOptionsT], ChatTelemetryLayer[AzureAIClientOptionsT], RawAzureAIClient[AzureAIClientOptionsT], Generic[AzureAIClientOptionsT], diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py index afa073c6ab..65922e76b2 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py @@ -87,6 +87,8 @@ def create_test_azure_ai_chat_client( client.middleware = None client.chat_middleware = [] client.function_middleware = [] + client._cached_chat_middleware_pipeline = None + client._cached_function_middleware_pipeline = None client.otel_provider_name = "azure.ai" client.function_invocation_configuration = { "enabled": True, @@ -151,6 +153,10 @@ def test_azure_ai_chat_client_init_auto_create_client( chat_client.agent_name = None chat_client.additional_properties = {} chat_client.middleware = None + chat_client.chat_middleware = [] + chat_client.function_middleware = [] + chat_client._cached_chat_middleware_pipeline = None + chat_client._cached_function_middleware_pipeline = None assert chat_client.agents_client is mock_agents_client assert chat_client.agent_id is None diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index c546ef5535..0aefbe12f3 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -216,8 +216,8 @@ class BedrockSettings(TypedDict, total=False): class BedrockChatClient( - ChatMiddlewareLayer[BedrockChatOptionsT], FunctionInvocationLayer[BedrockChatOptionsT], + ChatMiddlewareLayer[BedrockChatOptionsT], ChatTelemetryLayer[BedrockChatOptionsT], BaseChatClient[BedrockChatOptionsT], Generic[BedrockChatOptionsT], diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 4fd563d3e0..66740f5bf8 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -966,16 +966,7 @@ def _apply_get_response_docstrings() -> None: from .observability import ChatTelemetryLayer apply_layered_docstring(ChatTelemetryLayer.get_response, BaseChatClient.get_response) - apply_layered_docstring( - FunctionInvocationLayer.get_response, - ChatTelemetryLayer.get_response, - extra_keyword_args={ - "function_middleware": """ - Optional per-call function middleware. - When omitted, middleware configured on the client or forwarded from higher layers is used. - """, - }, - ) + apply_layered_docstring(FunctionInvocationLayer.get_response, ChatTelemetryLayer.get_response) apply_layered_docstring( ChatMiddlewareLayer.get_response, FunctionInvocationLayer.get_response, diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 66845a2e9d..381482b91a 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -742,12 +742,17 @@ def __init__(self, *middleware: AgentMiddlewareTypes): middleware: The list of agent middleware to include in the pipeline. """ super().__init__() + self._source_middleware: tuple[AgentMiddlewareTypes, ...] 
= tuple(middleware) self._middleware: list[AgentMiddleware] = [] if middleware: for mdlware in middleware: self._register_middleware(mdlware) + def matches(self, middleware: Sequence[AgentMiddlewareTypes]) -> bool: + """Return whether this pipeline was built from the provided middleware sequence.""" + return self._source_middleware == tuple(middleware) + def _register_middleware(self, middleware: AgentMiddlewareTypes) -> None: """Register an agent middleware item. @@ -824,12 +829,17 @@ def __init__(self, *middleware: FunctionMiddlewareTypes): middleware: The list of function middleware to include in the pipeline. """ super().__init__() + self._source_middleware: tuple[FunctionMiddlewareTypes, ...] = tuple(middleware) self._middleware: list[FunctionMiddleware] = [] if middleware: for mdlware in middleware: self._register_middleware(mdlware) + def matches(self, middleware: Sequence[FunctionMiddlewareTypes]) -> bool: + """Return whether this pipeline was built from the provided middleware sequence.""" + return self._source_middleware == tuple(middleware) + def _register_middleware(self, middleware: FunctionMiddlewareTypes) -> None: """Register a function middleware item. @@ -892,12 +902,17 @@ def __init__(self, *middleware: ChatMiddlewareTypes): middleware: The list of chat middleware to include in the pipeline. """ super().__init__() + self._source_middleware: tuple[ChatMiddlewareTypes, ...] = tuple(middleware) self._middleware: list[ChatMiddleware] = [] if middleware: for mdlware in middleware: self._register_middleware(mdlware) + def matches(self, middleware: Sequence[ChatMiddlewareTypes]) -> bool: + """Return whether this pipeline was built from the provided middleware sequence.""" + return self._source_middleware == tuple(middleware) + def _register_middleware(self, middleware: ChatMiddlewareTypes) -> None: """Register a chat middleware item. 
@@ -980,16 +995,26 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]): def __init__( self, *, - middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + middleware: Sequence[ChatMiddlewareTypes] | None = None, **kwargs: Any, ) -> None: - middleware_list = categorize_middleware(*(middleware or [])) - self.chat_middleware = middleware_list["chat"] - if "function_middleware" in kwargs and middleware_list["function"]: - raise ValueError("Cannot specify 'function_middleware' and 'middleware' at the same time.") - kwargs["function_middleware"] = middleware_list["function"] + self.chat_middleware = list(middleware) if middleware else [] + self._cached_chat_middleware_pipeline: ChatMiddlewarePipeline | None = None super().__init__(**kwargs) + def _get_chat_middleware_pipeline( + self, + middleware: Sequence[ChatMiddlewareTypes], + ) -> ChatMiddlewarePipeline: + effective_middleware = [*self.chat_middleware, *middleware] + if self._cached_chat_middleware_pipeline is not None and self._cached_chat_middleware_pipeline.matches( + effective_middleware + ): + return self._cached_chat_middleware_pipeline + + self._cached_chat_middleware_pipeline = ChatMiddlewarePipeline(*effective_middleware) + return self._cached_chat_middleware_pipeline + @overload def get_response( self, @@ -1052,14 +1077,8 @@ def get_response( kwargs["tokenizer"] = tokenizer effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} - call_middleware = kwargs.pop("middleware", effective_client_kwargs.pop("middleware", [])) - middleware = categorize_middleware(call_middleware) - effective_client_kwargs["function_middleware"] = middleware["function"] - - pipeline = ChatMiddlewarePipeline( - *self.chat_middleware, - *middleware["chat"], - ) + call_middleware = effective_client_kwargs.pop("middleware", []) + pipeline = self._get_chat_middleware_pipeline(call_middleware) # type: ignore[reportUnknownArgumentType] if not pipeline.has_middlewares: return super_get_response( # type: ignore[no-any-return] messages=messages, @@ -1134,12 +1153,25 @@ def __init__( ) -> None: middleware_list = categorize_middleware(middleware) self.agent_middleware = middleware_list["agent"] + self._cached_agent_middleware_pipeline: AgentMiddlewarePipeline | None = None # Pass middleware to super so BaseAgent can store it for dynamic rebuild super().__init__(*args, middleware=middleware, **kwargs) # type: ignore[call-arg] # Note: We intentionally don't extend client's middleware lists here. # Chat and function middleware is passed to the chat client at runtime via kwargs # in AgentMiddlewareLayer.run(), where it's properly combined with run-level middleware. 
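The `matches()` helpers and `_cached_*_middleware_pipeline` slots added above implement a small memoization pattern: each pipeline remembers the exact middleware tuple it was built from, and a layer reuses the cached instance whenever the effective sequence is unchanged, rebuilding only on change. A minimal standalone sketch of that pattern follows (hypothetical `_Pipeline`/`_Layer` names, not the framework's classes):

```python
from collections.abc import Sequence
from typing import Any


class _Pipeline:
    """Stand-in for the Chat/Function/AgentMiddlewarePipeline classes."""

    def __init__(self, *middleware: Any) -> None:
        # Remember exactly what this pipeline was built from (order matters).
        self._source_middleware: tuple[Any, ...] = tuple(middleware)

    def matches(self, middleware: Sequence[Any]) -> bool:
        """Return whether this pipeline was built from the provided sequence."""
        return self._source_middleware == tuple(middleware)


class _Layer:
    """Stand-in for a client layer holding base middleware plus a cache slot."""

    def __init__(self, base_middleware: Sequence[Any] = ()) -> None:
        self.base_middleware = list(base_middleware)
        self._cached_pipeline: _Pipeline | None = None

    def _get_pipeline(self, runtime_middleware: Sequence[Any]) -> _Pipeline:
        # The cache key is the full effective list, so changing either the
        # base middleware or the per-call middleware forces a rebuild.
        effective = [*self.base_middleware, *runtime_middleware]
        if self._cached_pipeline is not None and self._cached_pipeline.matches(effective):
            return self._cached_pipeline
        self._cached_pipeline = _Pipeline(*effective)
        return self._cached_pipeline
```

Keying on the full effective list (base plus per-call middleware) is what the `test_chat_middleware_pipeline_cache_includes_base_middleware` test later in this patch locks in: changing either side invalidates the cache.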
+ def _get_agent_middleware_pipeline( + self, + middleware: Sequence[AgentMiddlewareTypes], + ) -> AgentMiddlewarePipeline: + if self._cached_agent_middleware_pipeline is not None and self._cached_agent_middleware_pipeline.matches( + middleware + ): + return self._cached_agent_middleware_pipeline + + self._cached_agent_middleware_pipeline = AgentMiddlewarePipeline(*middleware) + return self._cached_agent_middleware_pipeline + @overload def run( self, @@ -1210,7 +1242,7 @@ def run( ) base_middleware_list = categorize_middleware(base_middleware) run_middleware_list = categorize_middleware(middleware) - pipeline = AgentMiddlewarePipeline(*base_middleware_list["agent"], *run_middleware_list["agent"]) + pipeline = self._get_agent_middleware_pipeline([*base_middleware_list["agent"], *run_middleware_list["agent"]]) # Combine base and run-level function/chat middleware for forwarding to chat client combined_function_chat_middleware = ( @@ -1392,7 +1424,7 @@ def categorize_middleware( all_middleware: list[Any] = [] for source in middleware_sources: if source: - if isinstance(source, list): + if isinstance(source, Sequence) and not isinstance(source, (str, bytes)): all_middleware.extend(source) # type: ignore else: all_middleware.append(source) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index c9810771de..cf7384588f 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -63,7 +63,12 @@ from ._clients import SupportsChatGetResponse from ._compaction import CompactionStrategy, TokenizerProtocol from ._mcp import MCPTool - from ._middleware import FunctionInvocationContext, FunctionMiddlewarePipeline, FunctionMiddlewareTypes + from ._middleware import ( + ChatAndFunctionMiddlewareTypes, + FunctionInvocationContext, + FunctionMiddlewarePipeline, + FunctionMiddlewareTypes, + ) from ._sessions import AgentSession from ._types import ( ChatOptions, @@ -2024,18 +2029,37 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): def __init__( self, *, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, **kwargs: Any, ) -> None: - self.function_middleware: list[FunctionMiddlewareTypes] = ( - list(function_middleware) if function_middleware else [] - ) + from ._middleware import categorize_middleware + + middleware_list = categorize_middleware(middleware) + self.function_middleware: list[FunctionMiddlewareTypes] = list(middleware_list["function"]) + self._cached_function_middleware_pipeline: FunctionMiddlewarePipeline | None = None self.function_invocation_configuration = normalize_function_invocation_configuration( function_invocation_configuration ) + if (chat_middleware := (middleware_list["chat"] or None)) is not None: + kwargs["middleware"] = chat_middleware super().__init__(**kwargs) + def _get_function_middleware_pipeline( + self, + middleware: Sequence[FunctionMiddlewareTypes], + ) -> FunctionMiddlewarePipeline: + from ._middleware import FunctionMiddlewarePipeline + + effective_middleware = [*self.function_middleware, *middleware] + if self._cached_function_middleware_pipeline is not None and self._cached_function_middleware_pipeline.matches( + effective_middleware + ): + return self._cached_function_middleware_pipeline + + self._cached_function_middleware_pipeline = 
FunctionMiddlewarePipeline(*effective_middleware) + return self._cached_function_middleware_pipeline + @overload def get_response( self, @@ -2043,6 +2067,7 @@ def get_response( *, stream: Literal[False] = ..., options: ChatOptions[ResponseModelBoundT], + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, @@ -2057,6 +2082,7 @@ def get_response( *, stream: Literal[False] = ..., options: OptionsCoT | ChatOptions[None] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, @@ -2071,6 +2097,7 @@ def get_response( *, stream: Literal[True], options: OptionsCoT | ChatOptions[Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, @@ -2084,14 +2111,14 @@ def get_response( *, stream: bool = False, options: OptionsCoT | ChatOptions[Any] | None = None, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - from ._middleware import FunctionMiddlewarePipeline + from ._middleware import categorize_middleware from ._types import ( ChatResponse, ChatResponseUpdate, @@ -2109,16 +2136,21 @@ def get_response( ) effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} - effective_function_middleware = function_middleware - if effective_function_middleware is None: - middleware_from_client_kwargs = effective_client_kwargs.pop("function_middleware", None) - if middleware_from_client_kwargs is not None: - effective_function_middleware = cast(Sequence[Any], middleware_from_client_kwargs) - - # ChatMiddleware adds this kwarg - function_middleware_pipeline = FunctionMiddlewarePipeline( - *(self.function_middleware), *(effective_function_middleware or []) - ) + if middleware is not None: + existing = effective_client_kwargs.get("middleware", []) + effective_client_kwargs["middleware"] = [ + *( + existing + if isinstance(existing, Sequence) and not isinstance(existing, (str, bytes)) + else [existing] + ), + *middleware, + ] + runtime_middleware = categorize_middleware(effective_client_kwargs.pop("middleware", [])) + + function_middleware_pipeline = self._get_function_middleware_pipeline(runtime_middleware["function"]) + if runtime_middleware["chat"]: + effective_client_kwargs["middleware"] = runtime_middleware["chat"] max_errors = self.function_invocation_configuration.get( "max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST ) diff --git a/python/packages/core/agent_framework/_workflows/_viz.py b/python/packages/core/agent_framework/_workflows/_viz.py index 0fcf8af32d..54015b066c 100644 --- a/python/packages/core/agent_framework/_workflows/_viz.py +++ b/python/packages/core/agent_framework/_workflows/_viz.py @@ -109,7 
+109,7 @@ def export( # Create a temporary graphviz Source object dot_content = self.to_digraph(include_internal_executors=include_internal_executors) - source = graphviz.Source(dot_content) + source = graphviz.Source(dot_content) # type: ignore[reportUnknownVariableType] try: if filename: @@ -131,7 +131,7 @@ def export( source.render(base_name, format=format, cleanup=True) # type: ignore return f"{base_name}.{format}" - except graphviz.backend.execute.ExecutableNotFound as e: + except graphviz.backend.execute.ExecutableNotFound as e: # type: ignore raise ImportError( "The graphviz executables are not found. The graphviz Python package is installed, but the " "graphviz executables (dot, neato, etc.) are not available on your system's PATH. " diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index 2ae21d124c..ef598ebe21 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -152,8 +152,8 @@ class AzureOpenAIChatOptions(OpenAIChatOptions[ResponseModelT], Generic[Response class AzureOpenAIChatClient( # type: ignore[misc] AzureOpenAIConfigMixin, - ChatMiddlewareLayer[AzureOpenAIChatOptionsT], FunctionInvocationLayer[AzureOpenAIChatOptionsT], + ChatMiddlewareLayer[AzureOpenAIChatOptionsT], ChatTelemetryLayer[AzureOpenAIChatOptionsT], RawOpenAIChatClient[AzureOpenAIChatOptionsT], Generic[AzureOpenAIChatOptionsT], diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index 192576bd04..8387e49591 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -51,8 +51,8 @@ class AzureOpenAIResponsesClient( # type: ignore[misc] AzureOpenAIConfigMixin, - ChatMiddlewareLayer[AzureOpenAIResponsesOptionsT], FunctionInvocationLayer[AzureOpenAIResponsesOptionsT], + ChatMiddlewareLayer[AzureOpenAIResponsesOptionsT], ChatTelemetryLayer[AzureOpenAIResponsesOptionsT], RawOpenAIResponsesClient[AzureOpenAIResponsesOptionsT], Generic[AzureOpenAIResponsesOptionsT], diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index dcabaae8fc..a0cbd6a1a0 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -362,11 +362,15 @@ def _create_otlp_exporters( if protocol == "grpc": # Import all gRPC exporters try: - from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter as GRPCLogExporter - from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( - OTLPMetricExporter as GRPCMetricExporter, + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import ( # type: ignore[reportMissingImports] + OTLPLogExporter as GRPCLogExporter, # type: ignore[reportUnknownVariableType] + ) + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( # type: ignore[reportMissingImports] + OTLPMetricExporter as GRPCMetricExporter, # type: ignore[reportUnknownVariableType] + ) + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( # type: ignore[reportMissingImports] + OTLPSpanExporter as GRPCSpanExporter, # type: ignore[reportUnknownVariableType] ) - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter except ImportError as exc: raise ImportError( 
"opentelemetry-exporter-otlp-proto-grpc is required for OTLP gRPC exporters. " @@ -375,21 +379,21 @@ def _create_otlp_exporters( if actual_logs_endpoint: exporters.append( - GRPCLogExporter( + GRPCLogExporter( # type: ignore[reportUnknownArgumentType] endpoint=actual_logs_endpoint, headers=actual_logs_headers if actual_logs_headers else None, ) ) if actual_traces_endpoint: exporters.append( - GRPCSpanExporter( + GRPCSpanExporter( # type: ignore[reportUnknownArgumentType] endpoint=actual_traces_endpoint, headers=actual_traces_headers if actual_traces_headers else None, ) ) if actual_metrics_endpoint: exporters.append( - GRPCMetricExporter( + GRPCMetricExporter( # type: ignore[reportUnknownArgumentType] endpoint=actual_metrics_endpoint, headers=actual_metrics_headers if actual_metrics_headers else None, ) diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index b1d5e8795c..9179fb4a8c 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -210,8 +210,8 @@ class OpenAIAssistantsOptions(ChatOptions[ResponseModelT], Generic[ResponseModel class OpenAIAssistantsClient( # type: ignore[misc] OpenAIConfigMixin, - ChatMiddlewareLayer[OpenAIAssistantsOptionsT], FunctionInvocationLayer[OpenAIAssistantsOptionsT], + ChatMiddlewareLayer[OpenAIAssistantsOptionsT], ChatTelemetryLayer[OpenAIAssistantsOptionsT], BaseChatClient[OpenAIAssistantsOptionsT], Generic[OpenAIAssistantsOptionsT], diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 6df57fe428..e69cad7e3f 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -31,7 +31,7 @@ from .._clients import BaseChatClient from .._docstrings import apply_layered_docstring -from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, FunctionMiddlewareTypes +from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer from .._settings import load_settings from .._tools import ( FunctionInvocationConfiguration, @@ -156,9 +156,9 @@ class RawOpenAIChatClient( # type: ignore[misc] you should consider which additional layers to apply. There is a defined ordering that you should follow: - 1. **ChatMiddlewareLayer** - Should be applied first as it also prepares function middleware - 2. **FunctionInvocationLayer** - Handles tool/function calling loop - 3. **ChatTelemetryLayer** - Must be inside the function calling loop for correct per-call telemetry + 1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware + 2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry + 3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry Use ``OpenAIChatClient`` instead for a fully-featured client with all layers applied. 
""" @@ -772,8 +772,8 @@ def service_url(self) -> str: class OpenAIChatClient( # type: ignore[misc] OpenAIConfigMixin, - ChatMiddlewareLayer[OpenAIChatOptionsT], FunctionInvocationLayer[OpenAIChatOptionsT], + ChatMiddlewareLayer[OpenAIChatOptionsT], ChatTelemetryLayer[OpenAIChatOptionsT], RawOpenAIChatClient[OpenAIChatOptionsT], Generic[OpenAIChatOptionsT], @@ -787,7 +787,6 @@ def get_response( *, stream: Literal[False] = ..., options: ChatOptions[ResponseModelBoundT], - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, @@ -801,7 +800,6 @@ def get_response( *, stream: Literal[False] = ..., options: OpenAIChatOptionsT | ChatOptions[None] | None = None, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, @@ -815,7 +813,6 @@ def get_response( *, stream: Literal[True], options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, @@ -829,7 +826,6 @@ def get_response( *, stream: bool = False, options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, @@ -840,14 +836,15 @@ def get_response( "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", super().get_response, # type: ignore[misc] ) + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + if middleware is not None: + effective_client_kwargs["middleware"] = middleware return super_get_response( # type: ignore[no-any-return] messages=messages, stream=stream, options=options, - function_middleware=function_middleware, function_invocation_kwargs=function_invocation_kwargs, - client_kwargs=client_kwargs, - middleware=middleware, + client_kwargs=effective_client_kwargs, **kwargs, ) @@ -963,10 +960,6 @@ def _apply_openai_chat_client_docstrings() -> None: OpenAIChatClient.get_response, RawOpenAIChatClient.get_response, extra_keyword_args={ - "function_middleware": """ - Optional per-call function middleware. - When omitted, middleware configured on the client or forwarded from higher layers is used. - """, "middleware": """ Optional per-call chat and function middleware. This is merged with any middleware configured on the client for the current request. diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 0769c3f1f9..0c57dffb39 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -249,9 +249,9 @@ class RawOpenAIResponsesClient( # type: ignore[misc] you should consider which additional layers to apply. There is a defined ordering that you should follow: - 1. 
**ChatMiddlewareLayer** - Should be applied first as it also prepares function middleware - 2. **FunctionInvocationLayer** - Handles tool/function calling loop - 3. **ChatTelemetryLayer** - Must be inside the function calling loop for correct per-call telemetry + 1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware + 2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry + 3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry Use ``OpenAIResponsesClient`` instead for a fully-featured client with all layers applied. """ @@ -2259,8 +2259,8 @@ def _get_metadata_from_response(self, output: Any) -> dict[str, Any]: class OpenAIResponsesClient( # type: ignore[misc] OpenAIConfigMixin, - ChatMiddlewareLayer[OpenAIResponsesOptionsT], FunctionInvocationLayer[OpenAIResponsesOptionsT], + ChatMiddlewareLayer[OpenAIResponsesOptionsT], ChatTelemetryLayer[OpenAIResponsesOptionsT], RawOpenAIResponsesClient[OpenAIResponsesOptionsT], Generic[OpenAIResponsesOptionsT], diff --git a/python/packages/core/tests/core/conftest.py b/python/packages/core/tests/core/conftest.py index 2d1eec2d9a..57c0cf5217 100644 --- a/python/packages/core/tests/core/conftest.py +++ b/python/packages/core/tests/core/conftest.py @@ -128,8 +128,8 @@ def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: class MockBaseChatClient( - ChatMiddlewareLayer[OptionsCoT], FunctionInvocationLayer[OptionsCoT], + ChatMiddlewareLayer[OptionsCoT], ChatTelemetryLayer[OptionsCoT], BaseChatClient[OptionsCoT], Generic[OptionsCoT], @@ -137,7 +137,7 @@ class MockBaseChatClient( """Mock implementation of a full-featured ChatClient.""" def __init__(self, **kwargs: Any): - super().__init__(function_middleware=[], **kwargs) + super().__init__(middleware=[], **kwargs) self.run_responses: list[ChatResponse] = [] self.streaming_responses: list[list[ChatResponseUpdate]] = [] self.call_count: int = 0 diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index 7e150c47c6..258a31d73b 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -74,8 +74,8 @@ def test_openai_chat_client_get_response_docstring_surfaces_layered_runtime_docs assert docstring is not None assert "Get a response from a chat client." in docstring assert "function_invocation_kwargs" in docstring - assert "function_middleware: Optional per-call function middleware." in docstring assert "middleware: Optional per-call chat and function middleware." in docstring + assert "function_middleware: Optional per-call function middleware." 
not in docstring def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None: @@ -84,7 +84,6 @@ def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None: signature = inspect.signature(OpenAIChatClient.get_response) assert OpenAIChatClient.get_response.__qualname__ == "OpenAIChatClient.get_response" - assert "function_middleware" in signature.parameters assert "middleware" in signature.parameters diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 3c61040289..d9659837a8 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -3226,7 +3226,7 @@ def ai_func(arg1: str) -> str: response = await chat_client_base.get_response( "hello", options={"tool_choice": "auto", "tools": [ai_func]}, - middleware=[TerminateLoopMiddleware()], + client_kwargs={"middleware": [TerminateLoopMiddleware()]}, ) # Function should NOT have been executed - middleware intercepted it @@ -3292,7 +3292,7 @@ def terminating_func(arg1: str) -> str: response = await chat_client_base.get_response( "hello", options={"tool_choice": "auto", "tools": [normal_func, terminating_func]}, - middleware=[SelectiveTerminateMiddleware()], + client_kwargs={"middleware": [SelectiveTerminateMiddleware()]}, ) # normal_function should have executed (middleware calls next_handler) @@ -3345,7 +3345,7 @@ def ai_func(arg1: str) -> str: async for update in chat_client_base.get_response( "hello", options={"tool_choice": "auto", "tools": [ai_func]}, - middleware=[TerminateLoopMiddleware()], + client_kwargs={"middleware": [TerminateLoopMiddleware()]}, stream=True, ): updates.append(update) @@ -3389,12 +3389,12 @@ async def test_conversation_id_updated_in_options_between_tool_iterations(): conversation_ids_received: list[str | None] = [] class TrackingChatClient( - ChatMiddlewareLayer, FunctionInvocationLayer, + ChatMiddlewareLayer, BaseChatClient, ): def __init__(self) -> None: - super().__init__(function_middleware=[]) + super().__init__(middleware=[]) self.run_responses: list[ChatResponse] = [] self.streaming_responses: list[list[ChatResponseUpdate]] = [] self.call_count: int = 0 diff --git a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py index 160ea0fcc4..11a738a0b9 100644 --- a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py +++ b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py @@ -84,8 +84,8 @@ def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: class FunctionInvokingMockClient( - ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], + ChatMiddlewareLayer[Any], ChatTelemetryLayer[Any], _MockBaseChatClient, ): diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index 6c559c40d4..0026cbf98f 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -28,6 +28,7 @@ FunctionMiddleware, FunctionMiddlewarePipeline, MiddlewareTermination, + categorize_middleware, ) from agent_framework._tools import FunctionTool @@ -1681,3 +1682,49 @@ def mock_chat_client() -> Any: client = MagicMock(spec=SupportsChatGetResponse) client.service_url = MagicMock(return_value="mock://test") return client + + +class TestCategorizeMiddleware: + 
"""Test cases for categorize_middleware.""" + + def test_categorize_middleware_with_tuple(self) -> None: + """Test that tuple middleware sources are unpacked, not appended as a single item.""" + chat_mw = TestChatMiddleware() + function_mw = TestFunctionMiddleware() + agent_mw = TestAgentMiddleware() + result = categorize_middleware((chat_mw, function_mw, agent_mw)) + assert result["chat"] == [chat_mw] + assert result["function"] == [function_mw] + assert result["agent"] == [agent_mw] + + def test_categorize_middleware_with_list(self) -> None: + """Test that list middleware sources are unpacked correctly.""" + chat_mw = TestChatMiddleware() + function_mw = TestFunctionMiddleware() + result = categorize_middleware([chat_mw, function_mw]) + assert result["chat"] == [chat_mw] + assert result["function"] == [function_mw] + assert result["agent"] == [] + + def test_categorize_middleware_with_none(self) -> None: + """Test that None middleware sources are handled.""" + result = categorize_middleware(None) + assert result["chat"] == [] + assert result["function"] == [] + assert result["agent"] == [] + + def test_categorize_middleware_with_single_item(self) -> None: + """Test that a single unwrapped middleware item is appended correctly.""" + chat_mw = TestChatMiddleware() + result = categorize_middleware(chat_mw) + assert result["chat"] == [chat_mw] + assert result["function"] == [] + assert result["agent"] == [] + + def test_categorize_middleware_with_string_does_not_decompose(self) -> None: + """Test that a string is not decomposed character-by-character.""" + result = categorize_middleware("not_a_middleware") + # String should be treated as a single item, not decomposed into characters + total_items = len(result["chat"]) + len(result["function"]) + len(result["agent"]) + assert total_items == 1 + assert result["agent"] == ["not_a_middleware"] diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index bfe5ec1293..6470a8202e 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -697,6 +697,26 @@ async def process( assert function_calls[0].name == "sample_tool_function" assert function_results[0].call_id == function_calls[0].call_id + def test_agent_middleware_pipeline_cache_reuses_matching_middleware(self) -> None: + """Test that identical agent middleware sets reuse the cached pipeline.""" + + @agent_middleware + async def first_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @agent_middleware + async def second_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + agent = Agent(client=MockBaseChatClient()) + + first_pipeline = agent._get_agent_middleware_pipeline([first_middleware]) + second_pipeline = agent._get_agent_middleware_pipeline([first_middleware]) + third_pipeline = agent._get_agent_middleware_pipeline([second_middleware]) + + assert first_pipeline is second_pipeline + assert third_pipeline is not first_pipeline + async def test_function_middleware_can_access_and_override_custom_kwargs( self, chat_client_base: "MockBaseChatClient" ) -> None: @@ -1969,6 +1989,77 @@ async def function_middleware( "agent_middleware_after", ] + async def test_combined_middleware_with_tool_loop(self) -> None: + """Test Agent middleware ordering when tool calls trigger multiple chat rounds.""" + execution_order: 
list[str] = [] + chat_round = 0 + client = MockBaseChatClient() + client.run_responses = [ + ChatResponse( + messages=[ + Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_123", + name="sample_tool_function", + arguments='{"location": "Seattle"}', + ) + ], + ) + ] + ), + ChatResponse(messages=[Message(role="assistant", text="Final response")]), + ] + + async def tracking_agent_middleware( + context: AgentContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + execution_order.append("agent_middleware_before") + await call_next() + execution_order.append("agent_middleware_after") + + async def tracking_chat_middleware( + context: ChatContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + nonlocal chat_round + chat_round += 1 + execution_order.append(f"chat_middleware_before_{chat_round}") + await call_next() + execution_order.append(f"chat_middleware_after_{chat_round}") + + async def tracking_function_middleware( + context: FunctionInvocationContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + execution_order.append("function_middleware_before") + await call_next() + execution_order.append("function_middleware_after") + + agent = Agent( + client=client, + middleware=[tracking_chat_middleware, tracking_function_middleware, tracking_agent_middleware], + tools=[sample_tool_function], + ) + + response = await agent.run([Message(role="user", text="test")]) + + assert response is not None + assert client.call_count == 2 + assert response.messages[-1].text == "Final response" + assert execution_order == [ + "agent_middleware_before", + "chat_middleware_before_1", + "chat_middleware_after_1", + "function_middleware_before", + "function_middleware_after", + "chat_middleware_before_2", + "chat_middleware_after_2", + "agent_middleware_after", + ] + async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None: """Test that agent middleware can access and override custom parameters like temperature.""" captured_kwargs: dict[str, Any] = {} diff --git a/python/packages/core/tests/core/test_middleware_with_chat.py b/python/packages/core/tests/core/test_middleware_with_chat.py index 62a168ccb0..5fa9d64031 100644 --- a/python/packages/core/tests/core/test_middleware_with_chat.py +++ b/python/packages/core/tests/core/test_middleware_with_chat.py @@ -274,7 +274,10 @@ async def counting_middleware(context: ChatContext, call_next: Callable[[], Awai # First call with run-level middleware messages = [Message(role="user", text="first message")] - response1 = await chat_client_base.get_response(messages, middleware=[counting_middleware]) + response1 = await chat_client_base.get_response( + messages, + client_kwargs={"middleware": [counting_middleware]}, + ) assert response1 is not None assert execution_count["count"] == 1 @@ -286,7 +289,10 @@ async def counting_middleware(context: ChatContext, call_next: Callable[[], Awai # Third call with run-level middleware again - should execute messages = [Message(role="user", text="third message")] - response3 = await chat_client_base.get_response(messages, middleware=[counting_middleware]) + response3 = await chat_client_base.get_response( + messages, + client_kwargs={"middleware": [counting_middleware]}, + ) assert response3 is not None assert execution_count["count"] == 2 # Should be 2 now @@ -335,6 +341,81 @@ async def kwargs_middleware(context: ChatContext, call_next: Callable[[], Awaita assert modified_kwargs["new_param"] == "added_by_middleware" assert 
modified_kwargs["custom_param"] == "test_value" # Should still be there + def test_chat_middleware_pipeline_cache_reuses_matching_middleware( + self, + chat_client_base: "MockBaseChatClient", + ) -> None: + """Test that identical chat middleware sets reuse the cached pipeline.""" + + @chat_middleware + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @chat_middleware + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + first_pipeline = chat_client_base._get_chat_middleware_pipeline([first_middleware]) + second_pipeline = chat_client_base._get_chat_middleware_pipeline([first_middleware]) + third_pipeline = chat_client_base._get_chat_middleware_pipeline([second_middleware]) + + assert first_pipeline is second_pipeline + assert third_pipeline is not first_pipeline + + def test_chat_middleware_pipeline_cache_includes_base_middleware( + self, + chat_client_base: "MockBaseChatClient", + ) -> None: + """Test that chat middleware cache key includes base middleware to prevent incorrect reuse.""" + + @chat_middleware + async def base_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @chat_middleware + async def runtime_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + # Without base middleware + pipeline_no_base = chat_client_base._get_chat_middleware_pipeline([runtime_middleware]) + + # With base middleware + chat_client_base.chat_middleware = [base_middleware] + pipeline_with_base = chat_client_base._get_chat_middleware_pipeline([runtime_middleware]) + + assert pipeline_with_base is not pipeline_no_base + + def test_function_middleware_pipeline_cache_reuses_matching_middleware( + self, + chat_client_base: "MockBaseChatClient", + ) -> None: + """Test that identical function middleware sets reuse the cached pipeline.""" + + @function_middleware + async def base_middleware(context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @function_middleware + async def first_runtime_middleware( + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] + ) -> None: + await call_next() + + @function_middleware + async def second_runtime_middleware( + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] + ) -> None: + await call_next() + + chat_client_base.function_middleware = [base_middleware] + + first_pipeline = chat_client_base._get_function_middleware_pipeline([first_runtime_middleware]) + second_pipeline = chat_client_base._get_function_middleware_pipeline([first_runtime_middleware]) + third_pipeline = chat_client_base._get_function_middleware_pipeline([second_runtime_middleware]) + + assert first_pipeline is second_pipeline + assert third_pipeline is not first_pipeline + async def test_function_middleware_registration_on_chat_client( self, chat_client_base: "MockBaseChatClient" ) -> None: @@ -450,7 +531,9 @@ def sample_tool(location: str) -> str: # Execute the chat client directly with run-level middleware and tools messages = [Message(role="user", text="What's the weather in New York?")] response = await client.get_response( - messages, options={"tools": [sample_tool_wrapped]}, middleware=[run_level_function_middleware] + messages, + options={"tools": [sample_tool_wrapped]}, + client_kwargs={"middleware": [run_level_function_middleware]}, ) # 
+
     async def test_function_middleware_registration_on_chat_client(
         self, chat_client_base: "MockBaseChatClient"
     ) -> None:
@@ -450,7 +531,9 @@ def sample_tool(location: str) -> str:
         # Execute the chat client directly with run-level middleware and tools
         messages = [Message(role="user", text="What's the weather in New York?")]
         response = await client.get_response(
-            messages, options={"tools": [sample_tool_wrapped]}, middleware=[run_level_function_middleware]
+            messages,
+            options={"tools": [sample_tool_wrapped]},
+            client_kwargs={"middleware": [run_level_function_middleware]},
         )
 
         # Verify response
@@ -463,3 +546,156 @@ def sample_tool(location: str) -> str:
             "run_level_function_middleware_before",
             "run_level_function_middleware_after",
         ]
+
+    async def test_run_level_chat_and_function_middleware_split_per_function_loop_round(self) -> None:
+        """Test mixed run-level middleware is split so chat middleware runs per model call."""
+        execution_order: list[str] = []
+        chat_round = 0
+
+        @chat_middleware
+        async def run_level_chat_middleware(
+            context: ChatContext,
+            call_next: Callable[[], Awaitable[None]],
+        ) -> None:
+            nonlocal chat_round
+            chat_round += 1
+            execution_order.append(f"chat_middleware_before_{chat_round}")
+            await call_next()
+            execution_order.append(f"chat_middleware_after_{chat_round}")
+
+        @function_middleware
+        async def run_level_function_middleware(
+            context: FunctionInvocationContext,
+            call_next: Callable[[], Awaitable[None]],
+        ) -> None:
+            execution_order.append("function_middleware_before")
+            await call_next()
+            execution_order.append("function_middleware_after")
+
+        def sample_tool(location: str) -> str:
+            """Get weather for a location."""
+            return f"Weather in {location}: sunny"
+
+        sample_tool_wrapped = FunctionTool(
+            func=sample_tool,
+            name="sample_tool",
+            description="Get weather for a location",
+            approval_mode="never_require",
+        )
+
+        client = MockBaseChatClient()
+        client.run_responses = [
+            ChatResponse(
+                messages=[
+                    Message(
+                        role="assistant",
+                        contents=[
+                            Content.from_function_call(
+                                call_id="call_3",
+                                name="sample_tool",
+                                arguments={"location": "Seattle"},
+                            )
+                        ],
+                    )
+                ]
+            ),
+            ChatResponse(messages=[Message(role="assistant", text="Based on the weather data, it's sunny!")]),
+        ]
+
+        response = await client.get_response(
+            [Message(role="user", text="What's the weather in Seattle?")],
+            options={"tools": [sample_tool_wrapped]},
+            client_kwargs={"middleware": [run_level_chat_middleware, run_level_function_middleware]},
+        )
+
+        assert response is not None
+        assert client.call_count == 2
+        assert response.messages[-1].text == "Based on the weather data, it's sunny!"
+        assert execution_order == [
+            "chat_middleware_before_1",
+            "chat_middleware_after_1",
+            "function_middleware_before",
+            "function_middleware_after",
+            "chat_middleware_before_2",
+            "chat_middleware_after_2",
+        ]
+
+    async def test_run_level_chat_and_function_middleware_split_per_function_loop_round_streaming(self) -> None:
+        """Test mixed run-level middleware is split so chat middleware runs per model call in streaming mode."""
+        execution_order: list[str] = []
+        chat_round = 0
+
+        @chat_middleware
+        async def run_level_chat_middleware(
+            context: ChatContext,
+            call_next: Callable[[], Awaitable[None]],
+        ) -> None:
+            nonlocal chat_round
+            chat_round += 1
+            execution_order.append(f"chat_middleware_before_{chat_round}")
+            await call_next()
+            execution_order.append(f"chat_middleware_after_{chat_round}")
+
+        @function_middleware
+        async def run_level_function_middleware(
+            context: FunctionInvocationContext,
+            call_next: Callable[[], Awaitable[None]],
+        ) -> None:
+            execution_order.append("function_middleware_before")
+            await call_next()
+            execution_order.append("function_middleware_after")
+
+        def sample_tool(location: str) -> str:
+            """Get weather for a location."""
+            return f"Weather in {location}: sunny"
+
+        sample_tool_wrapped = FunctionTool(
+            func=sample_tool,
+            name="sample_tool",
+            description="Get weather for a location",
+            approval_mode="never_require",
+        )
+
+        client = MockBaseChatClient()
+        client.streaming_responses = [
+            [
+                ChatResponseUpdate(
+                    contents=[
+                        Content.from_function_call(
+                            call_id="call_3",
+                            name="sample_tool",
+                            arguments='{"location": "Seattle"}',
+                        )
+                    ],
+                    role="assistant",
+                    finish_reason="tool_calls",
+                ),
+            ],
+            [
+                ChatResponseUpdate(
+                    contents=[Content.from_text("Based on the weather data, it's sunny!")],
+                    role="assistant",
+                    finish_reason="stop",
+                ),
+            ],
+        ]
+
+        updates: list[ChatResponseUpdate] = []
+        async for update in client.get_response(
+            [Message(role="user", text="What's the weather in Seattle?")],
+            options={"tools": [sample_tool_wrapped]},
+            client_kwargs={"middleware": [run_level_chat_middleware, run_level_function_middleware]},
+            stream=True,
+        ):
+            updates.append(update)
+
+        assert client.call_count == 2
+        assert len(updates) > 0
+        assert execution_order == [
+            "chat_middleware_before_1",
+            "chat_middleware_after_1",
+            "function_middleware_before",
+            "function_middleware_after",
+            "chat_middleware_before_2",
+            "chat_middleware_after_2",
+        ]
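Both tests above depend on the mixed run-level middleware list being split by kind before the function loop starts, so chat middleware wraps every model call while function middleware wraps each tool invocation. A minimal sketch of that partitioning, assuming the decorators tag each callable with a hypothetical `middleware_kind` attribute (the real decorators may mark types differently):

```python
from collections.abc import Callable
from typing import Any


def split_middleware(
    middleware: list[Callable[..., Any]],
) -> tuple[list[Callable[..., Any]], list[Callable[..., Any]]]:
    """Partition a mixed run-level middleware list into (chat, function) lists."""
    chat: list[Callable[..., Any]] = []
    function: list[Callable[..., Any]] = []
    for m in middleware:
        # Assumption: decorators stamp each callable with `middleware_kind`.
        if getattr(m, "middleware_kind", None) == "function":
            function.append(m)
        else:
            chat.append(m)
    return chat, function
```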
diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py
index 367e32bf92..7982985b94 100644
--- a/python/packages/core/tests/core/test_observability.py
+++ b/python/packages/core/tests/core/test_observability.py
@@ -2437,7 +2437,7 @@ def test_capture_response(span_exporter: InMemorySpanExporter):
 async def test_layer_ordering_span_sequence_with_function_calling(span_exporter: InMemorySpanExporter):
     """Test that with correct layer ordering, spans appear in the expected sequence.
 
-    When using the correct layer ordering (ChatMiddlewareLayer, FunctionInvocationLayer,
+    When using the correct layer ordering (FunctionInvocationLayer, ChatMiddlewareLayer,
     ChatTelemetryLayer, BaseChatClient), the spans should appear in this order:
     1. First 'chat' span (initial LLM call that returns function call)
     2. 'execute_tool' span (function invocation)
@@ -2454,11 +2454,11 @@ async def test_layer_ordering_span_sequence_with_function_calling(span_exporter:
     def get_weather(location: str) -> str:
         return f"The weather in {location} is sunny."
 
-    # Correct layer ordering: FunctionInvocationLayer BEFORE ChatTelemetryLayer
-    # This ensures each inner LLM call gets its own telemetry span
+    # Correct layer ordering: FunctionInvocationLayer BEFORE ChatMiddlewareLayer BEFORE ChatTelemetryLayer
+    # This ensures each inner LLM call traverses chat middleware and still gets its own telemetry span
     class MockChatClientWithLayers(
-        ChatMiddlewareLayer,
         FunctionInvocationLayer,
+        ChatMiddlewareLayer,
         ChatTelemetryLayer,
         BaseChatClient,
     ):
diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py
index 4c1e64cd7c..2566d031aa 100644
--- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py
+++ b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py
@@ -130,8 +130,8 @@ class FoundryLocalSettings(TypedDict, total=False):
 
 
 class FoundryLocalClient(
-    ChatMiddlewareLayer[FoundryLocalChatOptionsT],
     FunctionInvocationLayer[FoundryLocalChatOptionsT],
+    ChatMiddlewareLayer[FoundryLocalChatOptionsT],
     ChatTelemetryLayer[FoundryLocalChatOptionsT],
     RawOpenAIChatClient[FoundryLocalChatOptionsT],
     Generic[FoundryLocalChatOptionsT],
diff --git a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py
index 07b1945882..266ae8a107 100644
--- a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py
+++ b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py
@@ -273,7 +273,7 @@ def _load_gaia_local(repo_dir: Path, wanted_levels: list[int] | None = None, max
 
     for p in parquet_files:
         try:
-            import pyarrow.parquet as pq
+            import pyarrow.parquet as pq  # type: ignore[reportMissingImports]
 
             pq_any = cast(Any, pq)
             table: Any = pq_any.read_table(p)
diff --git a/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py b/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py
index 9526498cc2..3da1121910 100644
--- a/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py
+++ b/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py
@@ -7,8 +7,8 @@
 import importlib.metadata
 
 from agent_framework.observability import enable_instrumentation
-from agentlightning.tracer import (
-    AgentOpsTracer,  # pyright: ignore[reportMissingImports] # type: ignore[import-not-found]
+from agentlightning.tracer import (  # type: ignore[reportMissingImports]
    AgentOpsTracer,  # type: ignore[reportMissingImports, import-not-found]
 )
 
 try:
diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py
index b931c89499..0c7f232797 100644
--- a/python/packages/ollama/agent_framework_ollama/_chat_client.py
+++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py
@@ -285,8 +285,8 @@ class OllamaSettings(TypedDict, total=False):
 
 
 class OllamaChatClient(
-    ChatMiddlewareLayer[OllamaChatOptionsT],
     FunctionInvocationLayer[OllamaChatOptionsT],
+    ChatMiddlewareLayer[OllamaChatOptionsT],
     ChatTelemetryLayer[OllamaChatOptionsT],
     BaseChatClient[OllamaChatOptionsT],
 ):
diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py
index 43c2f9153a..5c594ed537 100644
--- a/python/packages/orchestrations/tests/test_handoff.py
+++ b/python/packages/orchestrations/tests/test_handoff.py
@@ -33,7 +33,7 @@
 from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff
 
 
-class MockChatClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
+class MockChatClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
     """Mock chat client for testing handoff workflows."""
 
     def __init__(
@@ -134,7 +134,7 @@ def __init__(
         super().__init__(client=MockChatClient(name=name, handoff_to=handoff_to), name=name, id=name)
 
 
-class ContextAwareRefundClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
+class ContextAwareRefundClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
     """Mock client that expects prior user context to remain available on resume."""
 
     def __init__(self) -> None:
@@ -298,7 +298,7 @@ def submit_refund_counted() -> str:
         execution_count += 1
         return "ok"
 
-    class ApprovalReplayClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
+    class ApprovalReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
         def __init__(self) -> None:
             ChatMiddlewareLayer.__init__(self)
             FunctionInvocationLayer.__init__(self)
@@ -383,7 +383,7 @@ async def test_handoff_resume_preserves_approval_function_call_for_stateless_run
     def submit_refund() -> str:
         return "ok"
 
-    class StrictStatelessApprovalClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
+    class StrictStatelessApprovalClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
         def __init__(self) -> None:
             ChatMiddlewareLayer.__init__(self)
             FunctionInvocationLayer.__init__(self)
@@ -475,7 +475,7 @@ async def _get() -> ChatResponse:
 async def test_handoff_replay_serializes_handoff_function_results() -> None:
     """Returning to the same agent must not replay dict tool outputs."""
 
-    class ReplaySafeHandoffClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
+    class ReplaySafeHandoffClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
         def __init__(self, name: str, handoff_sequence: list[str | None]) -> None:
             ChatMiddlewareLayer.__init__(self)
             FunctionInvocationLayer.__init__(self)
@@ -550,7 +550,7 @@ async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs(
     def submit_refund() -> str:
         return "submitted"
 
-    class RefundReplayClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
+    class RefundReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
         def __init__(self) -> None:
             ChatMiddlewareLayer.__init__(self)
             FunctionInvocationLayer.__init__(self)
@@ -608,7 +608,7 @@ async def _get() -> ChatResponse:
 
         return _get()
 
-    class OrderReplayClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
+    class OrderReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
         def __init__(self) -> None:
             ChatMiddlewareLayer.__init__(self)
             FunctionInvocationLayer.__init__(self)
@@ -907,7 +907,7 @@ async def async_termination(conv: list[Message]) -> bool:
 async def test_handoff_terminates_without_request_info_when_latest_response_meets_condition() -> None:
     """Termination triggered by the latest assistant response should not emit request_info."""
 
-    class FinalizingClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
+    class FinalizingClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
         def __init__(self) -> None:
             ChatMiddlewareLayer.__init__(self)
             FunctionInvocationLayer.__init__(self)
diff --git a/python/samples/02-agents/auto_retry.py b/python/samples/02-agents/auto_retry.py
index 7c985bd0c1..0a5169ad3d 100644
--- a/python/samples/02-agents/auto_retry.py
+++ b/python/samples/02-agents/auto_retry.py
@@ -114,10 +114,11 @@ class RetryingAzureOpenAIChatClient(AzureOpenAIChatClient):
 
 
 class RateLimitRetryMiddleware(ChatMiddleware):
-    """Chat middleware that retries the full request pipeline on rate limit errors.
+    """Chat middleware that retries a single model-call pipeline on rate limit errors.
 
     Register this middleware on an agent (or at the run level) to automatically
-    retry any call_next() invocation that raises RateLimitError.
+    retry any chat-model call that raises RateLimitError. In tool-loop scenarios,
+    the middleware applies independently to each inner model call.
     """
 
     def __init__(self, *, max_attempts: int = RETRY_ATTEMPTS) -> None:
@@ -154,8 +155,9 @@ async def rate_limit_retry_middleware(
     """Function-based chat middleware that retries on rate limit errors.
 
     Wrap call_next() with a tenacity @retry decorator so any RateLimitError
-    raised during model inference triggers an automatic retry with exponential
-    back-off.
+    raised during a single model call triggers an automatic retry with exponential
+    back-off. In tool-loop scenarios, the middleware applies independently to
+    each inner model call.
     """
 
     @retry(
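For reference, a minimal sketch of the tenacity pattern these docstrings describe, assuming a stand-in `RateLimitError`; in the sample itself the decorated body is the middleware's `await call_next()`, which is what makes each inner model call of a tool loop retry independently:

```python
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class RateLimitError(Exception):
    """Stand-in for the provider's rate limit exception."""


@retry(
    retry=retry_if_exception_type(RateLimitError),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    reraise=True,
)
async def _call_model_once() -> None:
    # In the real middleware this body is `await call_next()`, so tenacity
    # retries exactly one model call, not the whole tool loop.
    ...
```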
diff --git a/python/samples/02-agents/chat_client/custom_chat_client.py b/python/samples/02-agents/chat_client/custom_chat_client.py
index 5adcf50d15..7a9aaa95f6 100644
--- a/python/samples/02-agents/chat_client/custom_chat_client.py
+++ b/python/samples/02-agents/chat_client/custom_chat_client.py
@@ -29,7 +29,10 @@
 Custom Chat Client Implementation Example
 
 This sample demonstrates implementing a custom chat client and optionally composing
-middleware, telemetry, and function invocation layers explicitly.
+middleware, telemetry, and function invocation layers explicitly. The recommended
+layer order is `FunctionInvocationLayer -> ChatMiddlewareLayer -> ChatTelemetryLayer`
+so chat middleware runs within each tool-loop iteration while telemetry records
+per-call spans that exclude middleware latency.
 """
 
 
@@ -124,9 +127,9 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
 
 
 class EchoingChatClientWithLayers(  # type: ignore[misc]
+    FunctionInvocationLayer[OptionsT],
     ChatMiddlewareLayer[OptionsT],
     ChatTelemetryLayer[OptionsT],
-    FunctionInvocationLayer[OptionsT],
     EchoingChatClient,
 ):
     """Echoing chat client that explicitly composes middleware, telemetry, and function layers."""
diff --git a/python/samples/02-agents/middleware/README.md b/python/samples/02-agents/middleware/README.md
new file mode 100644
index 0000000000..754f96e815
--- /dev/null
+++ b/python/samples/02-agents/middleware/README.md
@@ -0,0 +1,37 @@
+# Middleware samples
+
+This folder contains focused middleware samples for `Agent`, chat clients, tools, sessions, and runtime context behavior.
+
+## Files
+
+| File | Description |
+|------|-------------|
+| [`agent_and_run_level_middleware.py`](./agent_and_run_level_middleware.py) | Demonstrates combining agent-level and run-level middleware. |
+| [`chat_middleware.py`](./chat_middleware.py) | Shows class-based and function-based chat middleware that can observe, modify, and override model calls. |
+| [`class_based_middleware.py`](./class_based_middleware.py) | Shows class-based agent and function middleware. |
+| [`decorator_middleware.py`](./decorator_middleware.py) | Demonstrates middleware registration with decorators. |
+| [`exception_handling_with_middleware.py`](./exception_handling_with_middleware.py) | Shows how middleware can handle failures and recover cleanly. |
+| [`function_based_middleware.py`](./function_based_middleware.py) | Shows function-based agent and function middleware. |
+| [`middleware_termination.py`](./middleware_termination.py) | Demonstrates stopping a middleware pipeline early. |
+| [`override_result_with_middleware.py`](./override_result_with_middleware.py) | Shows how middleware can replace the normal result. |
+| [`runtime_context_delegation.py`](./runtime_context_delegation.py) | Demonstrates delegating work with runtime context data. |
+| [`session_behavior_middleware.py`](./session_behavior_middleware.py) | Shows how middleware interacts with session-backed runs. |
+| [`shared_state_middleware.py`](./shared_state_middleware.py) | Demonstrates sharing mutable state across middleware invocations. |
+| [`usage_tracking_middleware.py`](./usage_tracking_middleware.py) | Demonstrates one chat middleware function that tracks per-call usage in non-streaming and streaming tool-loop runs. |
+
+## Running the usage tracking sample
+
+The usage tracking sample uses `OpenAIResponsesClient`, so set the usual OpenAI Responses environment variables first:
+
+```bash
+export OPENAI_API_KEY="your-openai-api-key"
+export OPENAI_RESPONSES_MODEL_ID="gpt-4.1-mini"
+```
+
+Then run:
+
+```bash
+uv run samples/02-agents/middleware/usage_tracking_middleware.py
+```
+
+The sample forces a tool call so you can see middleware output for each inner model call in both non-streaming and streaming modes.
diff --git a/python/samples/02-agents/middleware/agent_and_run_level_middleware.py b/python/samples/02-agents/middleware/agent_and_run_level_middleware.py
index 55ccce3507..158d90daee 100644
--- a/python/samples/02-agents/middleware/agent_and_run_level_middleware.py
+++ b/python/samples/02-agents/middleware/agent_and_run_level_middleware.py
@@ -51,10 +51,10 @@
     - Run middleware wraps only the agent for that specific run
     - Each middleware can modify the context before AND after calling next()
 
-    Note: Function and chat middleware (e.g., ``function_logging_middleware``) execute
-    during tool invocation *inside* the agent execution, not in the outer agent-middleware
-    chain shown above. They follow the same ordering principle: agent-level function/chat
-    middleware runs before run-level function/chat middleware.
+    Note: Function middleware executes during tool invocation, and chat middleware
+    executes around each model call inside the agent execution, not in the outer
+    agent-middleware chain shown above. They follow the same ordering principle:
+    agent-level function/chat middleware runs before run-level function/chat middleware.
 """
 
diff --git a/python/samples/02-agents/middleware/usage_tracking_middleware.py b/python/samples/02-agents/middleware/usage_tracking_middleware.py
new file mode 100644
index 0000000000..877d2a8a82
--- /dev/null
+++ b/python/samples/02-agents/middleware/usage_tracking_middleware.py
@@ -0,0 +1,185 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+"""
+This sample demonstrates a single chat middleware that tracks per-model-call usage
+for both non-streaming and streaming tool-loop runs.
+"""
+
+import asyncio
+from collections.abc import Awaitable, Callable
+from random import randint
+from typing import Annotated
+
+from agent_framework import (
+    Agent,
+    ChatContext,
+    ChatResponse,
+    ChatResponseUpdate,
+    ResponseStream,
+    chat_middleware,
+    tool,
+)
+from agent_framework.openai import OpenAIResponsesClient
+from dotenv import load_dotenv
+from pydantic import Field
+
+# Load environment variables from .env file
+load_dotenv()
+
+
+NON_STREAMING_CALL_COUNT = 0
+STREAMING_CALL_COUNT = 0
+
+
+# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
+# see samples/02-agents/tools/function_tool_with_approval.py
+# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
+@tool(approval_mode="never_require")
+def get_weather(
+    location: Annotated[str, Field(description="The location to get the weather for.")],
+) -> str:
+    """Get the weather for a given location."""
+    conditions = ["sunny", "cloudy", "rainy", "stormy"]
+    return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
+
+
+def _reset_usage_counters() -> None:
+    """Reset call counters between sample runs."""
+    global NON_STREAMING_CALL_COUNT, STREAMING_CALL_COUNT
+    NON_STREAMING_CALL_COUNT = 0
+    STREAMING_CALL_COUNT = 0
+
+
+def _create_agent() -> Agent:
+    """Create the shared agent used by both demonstrations."""
+    return Agent(
+        client=OpenAIResponsesClient(),
+        instructions=(
+            "You are a weather assistant. Always call the weather tool before answering weather questions, "
+            "then summarize the tool result in one short paragraph."
+        ),
+        tools=[get_weather],
+        middleware=[print_usage],
+    )
+
+
+@chat_middleware
+async def print_usage(
+    context: ChatContext,
+    call_next: Callable[[], Awaitable[None]],
+) -> None:
+    """Print usage for each inner model call in both non-streaming and streaming runs."""
+    global NON_STREAMING_CALL_COUNT, STREAMING_CALL_COUNT
+
+    if context.stream:
+        STREAMING_CALL_COUNT += 1
+        call_number = STREAMING_CALL_COUNT
+        usage_seen_in_updates = False
+
+        def capture_usage_update(update: ChatResponseUpdate) -> ChatResponseUpdate:
+            nonlocal usage_seen_in_updates
+
+            for content in update.contents:
+                if content.type == "usage":
+                    usage_seen_in_updates = True
+                    print(f"\n[Streaming model call #{call_number}] Usage update: {content.usage_details}")
+            return update
+
+        def capture_final_usage(result: ChatResponse) -> ChatResponse:
+            if not usage_seen_in_updates and result.usage_details:
+                print(f"\n[Streaming model call #{call_number}] Final usage: {result.usage_details}")
+            return result
+
+        context.stream_transform_hooks.append(capture_usage_update)
+        context.stream_result_hooks.append(capture_final_usage)
+        await call_next()
+        return
+
+    NON_STREAMING_CALL_COUNT += 1
+    call_number = NON_STREAMING_CALL_COUNT
+
+    await call_next()
+
+    response = context.result
+    if isinstance(response, ChatResponse) and response.usage_details:
+        print(f"[Non-streaming model call #{call_number}] Usage: {response.usage_details}")
+
+
+async def non_streaming_usage_example() -> None:
+    """Run the non-streaming usage tracking example."""
+    _reset_usage_counters()
+    print("\n=== Non-streaming per-call usage tracking ===")
+
+    # 1. Create an agent with middleware that prints usage after each inner model call.
+    agent = _create_agent()
+
+    # 2. Run a weather question and require a tool call so the function loop performs multiple model calls.
+ query = "What is the weather in Seattle, and should I bring an umbrella?" + print(f"User: {query}") + result = await agent.run( + query, + options={"tool_choice": "required"}, + ) + + # 3. Print the final user-visible answer after the middleware already logged per-call usage. + print(f"Assistant: {result.text}") + + +async def streaming_usage_example() -> None: + """Run the streaming usage tracking example.""" + _reset_usage_counters() + print("\n=== Streaming per-call usage tracking ===") + + # 1. Create an agent with middleware that watches streaming usage for each inner model call. + agent = _create_agent() + + # 2. Start a streaming run and force tool usage so the function loop performs multiple model calls. + query = "What is the weather in Portland, and should I bring a jacket?" + print(f"User: {query}") + print("Assistant: ", end="", flush=True) + stream: ResponseStream = agent.run( + query, + stream=True, + options={"tool_choice": "required"}, + ) + + # 3. Consume the stream normally while the middleware reports usage in the background. + async for update in stream: + if update.text: + print(update.text, end="", flush=True) + print() + + # 4. Finalize the stream so you can inspect the final response if needed. + final_response = await stream.get_final_response() + print(f"Final assistant message: {final_response.text}") + + +async def main() -> None: + """Run both usage tracking demonstrations.""" + print("=== Usage Tracking Middleware Example ===") + + await non_streaming_usage_example() + await streaming_usage_example() + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Usage Tracking Middleware Example === + +=== Non-streaming per-call usage tracking === +User: What is the weather in Seattle, and should I bring an umbrella? +[Non-streaming model call #1] Usage: {'input_tokens': ..., 'output_tokens': ..., ...} +[Non-streaming model call #2] Usage: {'input_tokens': ..., 'output_tokens': ..., ...} +Assistant: Based on the weather in Seattle, ... + +=== Streaming per-call usage tracking === +User: What is the weather in Portland, and should I bring a jacket? +Assistant: Based on the weather in Portland, ... +[Streaming model call #1] Usage update: {'input_tokens': ..., 'output_tokens': ..., ...} +[Streaming model call #2] Usage update: {'input_tokens': ..., 'output_tokens': ..., ...} +Final assistant message: Based on the weather in Portland, ... +""" diff --git a/python/samples/02-agents/observability/advanced_manual_setup_console_output.py b/python/samples/02-agents/observability/advanced_manual_setup_console_output.py index 7afd359264..af7fcc6287 100644 --- a/python/samples/02-agents/observability/advanced_manual_setup_console_output.py +++ b/python/samples/02-agents/observability/advanced_manual_setup_console_output.py @@ -96,10 +96,16 @@ async def run_chat_client() -> None: stream: Whether to use streaming for the plugin Remarks: - When function calling is outside the open telemetry loop - each of the call to the model is handled as a seperate span, - while when the open telemetry is put last, a single span - is shown, which might include one or more rounds of function calling. + By default, the built-in non-`Raw...Client` chat clients already compose + the layers in this order: + `FunctionInvocationLayer -> ChatMiddlewareLayer -> ChatTelemetryLayer -> Raw/Base client`. + + When `FunctionInvocationLayer` is outside `ChatTelemetryLayer`, + each call to the model is handled as a separate span. 
+ Keep `ChatMiddlewareLayer` outside telemetry + so middleware latency does not skew those timings. + By contrast, when telemetry is placed outside the function loop, + a single span can cover one or more rounds of function calling. So for the scenario below, you should see the following: diff --git a/python/samples/02-agents/observability/advanced_zero_code.py b/python/samples/02-agents/observability/advanced_zero_code.py index 477a5b4d9b..981b14a0e6 100644 --- a/python/samples/02-agents/observability/advanced_zero_code.py +++ b/python/samples/02-agents/observability/advanced_zero_code.py @@ -71,10 +71,12 @@ async def run_chat_client(client: "SupportsChatGetResponse", stream: bool = Fals stream: Whether to use streaming for the plugin Remarks: - When function calling is outside the open telemetry loop - each of the call to the model is handled as a separate span, - while when the open telemetry is put last, a single span - is shown, which might include one or more rounds of function calling. + When `FunctionInvocationLayer` is outside `ChatTelemetryLayer`, + each call to the model is handled as a separate span. + If `ChatMiddlewareLayer` is present, keep it outside telemetry + so middleware latency does not skew those timings. + By contrast, when telemetry is placed outside the function loop, + a single span can cover one or more rounds of function calling. So for the scenario below, you should see the following: diff --git a/python/samples/02-agents/providers/custom/README.md b/python/samples/02-agents/providers/custom/README.md index f2d67e0315..ac58a77e69 100644 --- a/python/samples/02-agents/providers/custom/README.md +++ b/python/samples/02-agents/providers/custom/README.md @@ -37,17 +37,17 @@ The framework provides `Raw...Client` classes (e.g., `RawOpenAIChatClient`, `Raw There is a defined ordering for applying layers that you should follow: -1. **ChatMiddlewareLayer** - Should be applied **first** because it also prepares function middleware -2. **FunctionInvocationLayer** - Handles tool/function calling loop -3. **ChatTelemetryLayer** - Must be **inside** the function calling loop for correct per-call telemetry +1. **FunctionInvocationLayer** - Handles the tool/function calling loop and should stay outermost +2. **ChatMiddlewareLayer** - Wraps each model call in the loop and stays outside telemetry +3. **ChatTelemetryLayer** - Must be inside the function calling loop so each model call gets its own telemetry span 4. **Raw...Client** - The base implementation (e.g., `RawOpenAIChatClient`) Example of correct layer composition: ```python class MyCustomClient( - ChatMiddlewareLayer[TOptions], FunctionInvocationLayer[TOptions], + ChatMiddlewareLayer[TOptions], ChatTelemetryLayer[TOptions], RawOpenAIChatClient[TOptions], # or BaseChatClient for custom implementations Generic[TOptions], diff --git a/python/samples/05-end-to-end/hosted_agents/agent_with_local_tools/main.py b/python/samples/05-end-to-end/hosted_agents/agent_with_local_tools/main.py index 1faad6a2e9..4c60902dc2 100644 --- a/python/samples/05-end-to-end/hosted_agents/agent_with_local_tools/main.py +++ b/python/samples/05-end-to-end/hosted_agents/agent_with_local_tools/main.py @@ -16,7 +16,6 @@ from azure.identity.aio import AzureCliCredential, ManagedIdentityCredential from dotenv import load_dotenv - load_dotenv(override=True) # Configure these for your Foundry project
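To see mechanically why the layer ordering above works, here is a small self-contained illustration using plain stand-in classes (not framework types): Python's MRO calls the leftmost mixin first, so the function loop sits outside chat middleware, which sits outside telemetry, which wraps the raw model call.

```python
class FunctionLoop:
    def get_response(self) -> str:
        # Outermost: owns the tool-calling loop.
        return "function_loop(" + super().get_response() + ")"  # type: ignore[misc]


class Middleware:
    def get_response(self) -> str:
        # Wraps each model call, outside telemetry.
        return "middleware(" + super().get_response() + ")"  # type: ignore[misc]


class Telemetry:
    def get_response(self) -> str:
        # Innermost layer: times only the model call itself.
        return "telemetry(" + super().get_response() + ")"  # type: ignore[misc]


class RawClient:
    def get_response(self) -> str:
        return "model_call"


class Composed(FunctionLoop, Middleware, Telemetry, RawClient):
    pass


print(Composed().get_response())
# -> function_loop(middleware(telemetry(model_call)))
```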