diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index 7c0f7b8cf9..74bb50e306 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -28,25 +28,21 @@ from ._message_adapters import agent_framework_messages_to_agui from ._utils import convert_tools_to_agui_format -if TYPE_CHECKING: - from ._types import AGUIChatOptions - -from typing import TypedDict - if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import TypeVar - + from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover - if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover + from typing import Self, TypedDict # pragma: no cover else: - from typing_extensions import Self # pragma: no cover + from typing_extensions import Self, TypedDict # pragma: no cover + +if TYPE_CHECKING: + from ._types import AGUIChatOptions logger: logging.Logger = logging.getLogger(__name__) @@ -85,7 +81,7 @@ async def streaming_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncIterab @wraps(original_get_response) async def response_wrapper(self: Any, *args: Any, **kwargs: Any) -> ChatResponse: - response = await original_get_response(self, *args, **kwargs) + response: ChatResponse[Any] = await original_get_response(self, *args, **kwargs) # type: ignore[var-annotated] if response.messages: for message in response.messages: _unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], message.contents)) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_types.py b/python/packages/ag-ui/agent_framework_ag_ui/_types.py index a80cd155d2..eb7124208a 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_types.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_types.py @@ -3,15 +3,23 @@ """Type definitions for AG-UI integration.""" import sys -from typing import Any, TypedDict +from typing import Any, Generic from agent_framework import ChatOptions from pydantic import BaseModel, Field if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import TypeVar + from typing_extensions import TypeVar # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover + + +TAGUIChatOptions = TypeVar("TAGUIChatOptions", bound=TypedDict, default="AGUIChatOptions", covariant=True) # type: ignore[valid-type] +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None) class PredictStateConfig(TypedDict): @@ -76,7 +84,7 @@ class AGUIRequest(BaseModel): # region AG-UI Chat Options TypedDict -class AGUIChatOptions(ChatOptions, total=False): +class AGUIChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False): """AG-UI protocol-specific chat options dict. Extends base ChatOptions for the AG-UI (Agent-UI) protocol. 
@@ -140,7 +148,5 @@ class AGUIChatOptions(ChatOptions, total=False): AGUI_OPTION_TRANSLATIONS: dict[str, str] = {} """Maps ChatOptions keys to AG-UI parameter names (protocol uses standard names).""" -TAGUIChatOptions = TypeVar("TAGUIChatOptions", bound=TypedDict, default="AGUIChatOptions", covariant=True) # type: ignore[valid-type] - # endregion diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py index c158dd7749..01b333e7f4 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py @@ -12,6 +12,10 @@ from typing import TypeVar # type: ignore # pragma: no cover else: from typing_extensions import TypeVar # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover if TYPE_CHECKING: from agent_framework import ChatOptions diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index 0413e8ab3c..630b92ca02 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -2,7 +2,7 @@ import sys from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence -from typing import Any, ClassVar, Final, Generic, Literal, TypedDict +from typing import Any, ClassVar, Final, Generic, Literal from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -47,15 +47,18 @@ ) from pydantic import BaseModel, SecretStr, ValidationError +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import TypeVar - + from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # type: ignore # pragma: no cover __all__ = [ "AnthropicChatOptions", @@ -69,6 +72,8 @@ BETA_FLAGS: Final[list[str]] = ["mcp-client-2025-04-04", "code-execution-2025-08-25"] STRUCTURED_OUTPUTS_BETA_FLAG: Final[str] = "structured-outputs-2025-11-13" +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None) + # region Anthropic Chat Options TypedDict @@ -91,7 +96,7 @@ class ThinkingConfig(TypedDict, total=False): budget_tokens: int -class AnthropicChatOptions(ChatOptions, total=False): +class AnthropicChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False): """Anthropic-specific chat options. Extends ChatOptions with options specific to Anthropic's Messages API. 
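Note on the pattern above: the provider option dicts (AGUIChatOptions, AnthropicChatOptions) are now parameterized by the response model, so a typed response_format entry can carry its model type through to the response value. The sketch below is illustrative only and not part of this change; ChatOptionsSketch, ProviderChatOptionsSketch, Weather and "placeholder-model" are placeholder names, and it assumes pydantic v2 plus a typing_extensions release with PEP 696 TypeVar defaults.

from collections.abc import Mapping
from typing import Any, Generic

from pydantic import BaseModel
from typing_extensions import TypedDict, TypeVar  # TypeVar defaults need typing_extensions below 3.13

TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)


class ChatOptionsSketch(TypedDict, Generic[TResponseModel], total=False):
    # Simplified stand-in for the base ChatOptions TypedDict.
    model_id: str
    temperature: float
    response_format: type[TResponseModel] | Mapping[str, Any] | None


class ProviderChatOptionsSketch(ChatOptionsSketch[TResponseModel], Generic[TResponseModel], total=False):
    # Simplified stand-in for a provider dict such as AnthropicChatOptions:
    # it re-exposes the type parameter and adds provider-only keys.
    max_tokens: int


class Weather(BaseModel):
    city: str
    temperature_c: float


# A dict literal annotated with the parameterized options carries the model type,
# which is what lets the client overloads later in this diff narrow the value type.
options: ProviderChatOptionsSketch[Weather] = {
    "model_id": "placeholder-model",
    "max_tokens": 256,
    "response_format": Weather,
}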
diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py index e5e9410ecf..b064294a7c 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_agent_provider.py @@ -2,7 +2,7 @@ import sys from collections.abc import Callable, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast +from typing import TYPE_CHECKING, Any, Generic, cast from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -27,9 +27,13 @@ from ._chat_client import AzureAIAgentOptions if sys.version_info >= (3, 13): - from typing import Self, TypeVar # pragma: no cover + from typing import Self, TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import Self, TypeVar # pragma: no cover + from typing_extensions import Self, TypeVar # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover # Type variable for options - allows typed ChatAgent[TOptions] returns diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 5626ead9a6..c54334aaef 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -6,7 +6,7 @@ import re import sys from collections.abc import AsyncIterable, Callable, Mapping, MutableMapping, MutableSequence, Sequence -from typing import Any, ClassVar, Generic, TypedDict +from typing import Any, ClassVar, Generic from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -96,9 +96,9 @@ else: from typing_extensions import override # type: ignore[import] # pragma: no cover if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover + from typing import Self, TypedDict # type: ignore # pragma: no cover else: - from typing_extensions import Self # pragma: no cover + from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover logger = get_logger("agent_framework.azure") @@ -1265,7 +1265,7 @@ def as_agent( | MutableMapping[str, Any] | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None = None, - default_options: TAzureAIAgentOptions | None = None, + default_options: TAzureAIAgentOptions | Mapping[str, Any] | None = None, chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, context_provider: ContextProvider | None = None, middleware: Sequence[Middleware] | None = None, diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 08623c3aa4..b70cdeafdc 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -2,7 +2,7 @@ import sys from collections.abc import Callable, Mapping, MutableMapping, MutableSequence, Sequence -from typing import Any, ClassVar, Generic, TypedDict, TypeVar, cast +from typing import Any, ClassVar, Generic, TypeVar, cast from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -38,9 +38,9 @@ else: from typing_extensions import override # type: ignore[import] # pragma: no cover if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover + from 
typing import Self, TypedDict # type: ignore # pragma: no cover else: - from typing_extensions import Self # pragma: no cover + from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover logger = get_logger("agent_framework.azure") @@ -551,7 +551,7 @@ def as_agent( | MutableMapping[str, Any] | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None = None, - default_options: TAzureAIClientOptions | None = None, + default_options: TAzureAIClientOptions | Mapping[str, Any] | None = None, chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, context_provider: ContextProvider | None = None, middleware: Sequence[Middleware] | None = None, diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py index dc0d8ea279..fa1d80da21 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py @@ -2,7 +2,7 @@ import sys from collections.abc import Callable, MutableMapping, Sequence -from typing import Any, Generic, TypedDict +from typing import Any, Generic from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -33,9 +33,13 @@ from ._shared import AzureAISettings, create_text_format_config, from_azure_ai_tools, to_azure_ai_tools if sys.version_info >= (3, 13): - from typing import Self, TypeVar # pragma: no cover + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import Self, TypeVar # pragma: no cover + from typing_extensions import TypeVar # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import Self, TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover logger = get_logger("agent_framework.azure") diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index a58d68e077..d7e0754c2b 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -5,7 +5,7 @@ import sys from collections import deque from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence -from typing import Any, ClassVar, Generic, Literal, TypedDict +from typing import Any, ClassVar, Generic, Literal from uuid import uuid4 from agent_framework import ( @@ -33,17 +33,20 @@ from boto3.session import Session as Boto3Session from botocore.client import BaseClient from botocore.config import Config as BotoConfig -from pydantic import SecretStr, ValidationError +from pydantic import BaseModel, SecretStr, ValidationError if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import TypeVar - + from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover logger = 
get_logger("agent_framework.bedrock") @@ -55,6 +58,8 @@ "BedrockSettings", ] +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None) + # region Bedrock Chat Options TypedDict @@ -82,7 +87,7 @@ class BedrockGuardrailConfig(TypedDict, total=False): """How to process guardrails during streaming (sync blocks, async does not).""" -class BedrockChatOptions(ChatOptions, total=False): +class BedrockChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False): """Amazon Bedrock Converse API-specific chat options dict. Extends base ChatOptions with Bedrock-specific parameters. diff --git a/python/packages/chatkit/agent_framework_chatkit/_converter.py b/python/packages/chatkit/agent_framework_chatkit/_converter.py index b83fd40812..894d54831d 100644 --- a/python/packages/chatkit/agent_framework_chatkit/_converter.py +++ b/python/packages/chatkit/agent_framework_chatkit/_converter.py @@ -30,9 +30,9 @@ ) if sys.version_info >= (3, 11): - from typing import assert_never + from typing import assert_never # type:ignore # pragma: no cover else: - from typing_extensions import assert_never + from typing_extensions import assert_never # type:ignore # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index f33b26d053..4dc6df2eac 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -3,7 +3,7 @@ import inspect import re import sys -from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack from copy import deepcopy from itertools import chain @@ -13,8 +13,8 @@ ClassVar, Generic, Protocol, - TypedDict, cast, + overload, runtime_checkable, ) from uuid import uuid4 @@ -43,24 +43,25 @@ from .exceptions import AgentExecutionException, AgentInitializationError from .observability import use_agent_instrumentation -if TYPE_CHECKING: - from ._types import ChatOptions - - if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import TypeVar - + from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover - if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover + from typing import Self, TypedDict # pragma: no cover else: - from typing_extensions import Self # pragma: no cover + from typing_extensions import Self, TypedDict # pragma: no cover + +if TYPE_CHECKING: + from ._types import ChatOptions + + +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None, covariant=True) +TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel) logger = get_logger("agent_framework") @@ -622,6 +623,7 @@ def __init__( provider-specific options including temperature, max_tokens, model_id, tool_choice, and provider-specific options like reasoning_effort. You can also create your own TypedDict for custom chat clients. + Note: response_format typing does not flow into run outputs when set via default_options. 
These can be overridden at runtime via the ``options`` parameter of ``run()`` and ``run_stream()``. tools: The tools to use for the request. kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``. @@ -657,6 +659,14 @@ def __init__( # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) + tools_ = cast( + ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None, + tools_, + ) # Handle instructions - named parameter takes precedence over options instructions_ = instructions if instructions is not None else opts.pop("instructions", None) @@ -742,6 +752,7 @@ def _update_agent_name_and_description(self) -> None: ): # type: ignore[reportAttributeAccessIssue, attr-defined] self.chat_client._update_agent_name_and_description(self.name, self.description) # type: ignore[reportAttributeAccessIssue, attr-defined] + @overload async def run( self, messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, @@ -752,9 +763,38 @@ async def run( | MutableMapping[str, Any] | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None = None, - options: TOptions_co | None = None, + options: "ChatOptions[TResponseModelT]", **kwargs: Any, - ) -> AgentResponse: + ) -> AgentResponse[TResponseModelT]: ... + + @overload + async def run( + self, + messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + tools: ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None = None, + options: TOptions_co | Mapping[str, Any] | "ChatOptions[Any]" | None = None, + **kwargs: Any, + ) -> AgentResponse[Any]: ... + + async def run( + self, + messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + tools: ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None = None, + options: TOptions_co | Mapping[str, Any] | "ChatOptions[Any]" | None = None, + **kwargs: Any, + ) -> AgentResponse[Any]: """Run the agent with the given messages and options. 
Note: @@ -784,6 +824,14 @@ async def run( # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) + tools_ = cast( + ToolProtocol + | Callable[..., Any] + | MutableMapping[str, Any] + | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] + | None, + tools_, + ) input_messages = normalize_messages(messages) thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages( @@ -860,12 +908,19 @@ async def run( response.messages, **{k: v for k, v in kwargs.items() if k != "thread"}, ) + response_format = co.get("response_format") + if not ( + response_format is not None and isinstance(response_format, type) and issubclass(response_format, BaseModel) + ): + response_format = None + return AgentResponse( messages=response.messages, response_id=response.response_id, created_at=response.created_at, usage_details=response.usage_details, value=response.value, + response_format=response_format, raw_representation=response, additional_properties=response.additional_properties, ) @@ -880,7 +935,7 @@ async def run_stream( | MutableMapping[str, Any] | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None = None, - options: TOptions_co | None = None, + options: TOptions_co | Mapping[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: """Stream the agent with the given messages and options. diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 12e975df6c..68d9d0312f 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -6,6 +6,7 @@ from collections.abc import ( AsyncIterable, Callable, + Mapping, MutableMapping, MutableSequence, Sequence, @@ -17,9 +18,13 @@ Generic, Protocol, TypedDict, + cast, + overload, runtime_checkable, ) +from pydantic import BaseModel + from ._logging import get_logger from ._memory import ContextProvider from ._middleware import ( @@ -45,9 +50,10 @@ ) if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import TypeVar + from typing_extensions import TypeVar # type: ignore # pragma: no cover + if TYPE_CHECKING: from ._agents import ChatAgent @@ -120,6 +126,16 @@ async def _stream(): additional_properties: dict[str, Any] + @overload + async def get_response( + self, + messages: str | ChatMessage | Sequence[str | ChatMessage], + *, + options: "ChatOptions[TResponseModelT]", + **kwargs: Any, + ) -> "ChatResponse[TResponseModelT]": ... + + @overload async def get_response( self, messages: str | ChatMessage | Sequence[str | ChatMessage], @@ -175,6 +191,9 @@ def get_streaming_response( covariant=True, ) +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None, covariant=True) +TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel) + class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]): """Base class for chat clients. @@ -319,13 +338,31 @@ async def _inner_get_streaming_response( # region Public method + @overload + async def get_response( + self, + messages: str | ChatMessage | Sequence[str | ChatMessage], + *, + options: "ChatOptions[TResponseModelT]", + **kwargs: Any, + ) -> ChatResponse[TResponseModelT]: ... 
+ + @overload async def get_response( self, messages: str | ChatMessage | Sequence[str | ChatMessage], *, options: TOptions_co | None = None, **kwargs: Any, - ) -> ChatResponse: + ) -> ChatResponse: ... + + async def get_response( + self, + messages: str | ChatMessage | Sequence[str | ChatMessage], + *, + options: TOptions_co | "ChatOptions[Any]" | None = None, + **kwargs: Any, + ) -> ChatResponse[Any]: """Get a response from a chat client. Args: @@ -389,7 +426,7 @@ def as_agent( | MutableMapping[str, Any] | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] | None = None, - default_options: TOptions_co | None = None, + default_options: TOptions_co | Mapping[str, Any] | None = None, chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, context_provider: ContextProvider | None = None, middleware: Sequence[Middleware] | None = None, @@ -410,6 +447,8 @@ def as_agent( default_options: A TypedDict containing chat options. When using a typed client like ``OpenAIChatClient``, this enables IDE autocomplete for provider-specific options including temperature, max_tokens, model_id, tool_choice, and more. + Note: response_format typing does not flow into run outputs when set via default_options, + and dict literals are accepted without specialized option typing. chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol. If not provided, the default in-memory store will be used. context_provider: Context providers to include during agent invocation. @@ -446,7 +485,7 @@ def as_agent( description=description, instructions=instructions, tools=tools, - default_options=default_options, + default_options=cast(Any, default_options), chat_message_store_factory=chat_message_store_factory, context_provider=context_provider, middleware=middleware, diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 4ba922c464..c41c2e7b5b 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -1,11 +1,12 @@ # Copyright (c) Microsoft. All rights reserved. 
import inspect +import sys from abc import ABC, abstractmethod from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableSequence, Sequence from enum import Enum from functools import update_wrapper -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar from ._serialization import SerializationMixin from ._types import AgentResponse, AgentResponseUpdate, ChatMessage, normalize_messages, prepare_messages @@ -20,6 +21,10 @@ from ._tools import FunctionTool from ._types import ChatResponse, ChatResponseUpdate +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover __all__ = [ "AgentMiddleware", diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 8ef899a5f7..2ebd7b9015 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -24,12 +24,11 @@ Generic, Literal, Protocol, - TypedDict, - TypeVar, Union, cast, get_args, get_origin, + overload, runtime_checkable, ) @@ -42,7 +41,7 @@ from .observability import ( OPERATION_DURATION_BUCKET_BOUNDARIES, OtelAttr, - capture_exception, # type: ignore + capture_exception, get_function_span, get_function_span_attributes, get_meter, @@ -57,20 +56,21 @@ Content, ) -from typing import overload +# TypeVar with defaults support for Python < 3.13 +if sys.version_info >= (3, 13): + from typing import TypeVar as TypeVar # type: ignore # pragma: no cover +else: + from typing_extensions import TypeVar as TypeVar # type: ignore[import] # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover - -# TypeVar with defaults support for Python < 3.13 -if sys.version_info >= (3, 13): - from typing import TypeVar as TypeVarWithDefaults # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover else: - from typing_extensions import ( - TypeVar as TypeVarWithDefaults, # type: ignore[import] # pragma: no cover - ) + from typing_extensions import TypedDict # type: ignore # pragma: no cover + logger = get_logger() @@ -97,8 +97,8 @@ TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol[Any]") # region Helpers -ArgsT = TypeVarWithDefaults("ArgsT", bound=BaseModel, default=BaseModel) -ReturnT = TypeVarWithDefaults("ReturnT", default=Any) +ArgsT = TypeVar("ArgsT", bound=BaseModel, default=BaseModel) +ReturnT = TypeVar("ReturnT", default=Any) def _parse_inputs( diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 3517fcfb41..9c49d25845 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. 
- import base64 import json import re +import sys from collections.abc import ( AsyncIterable, Callable, @@ -12,7 +12,7 @@ Sequence, ) from copy import deepcopy -from typing import Any, ClassVar, Final, Literal, TypedDict, TypeVar, overload +from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, cast, overload from pydantic import BaseModel, ValidationError @@ -21,6 +21,15 @@ from ._tools import ToolProtocol, tool from .exceptions import AdditionItemMismatch, ContentError +if sys.version_info >= (3, 13): + from typing import TypeVar # pragma: no cover +else: + from typing_extensions import TypeVar # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover + __all__ = [ "AgentResponse", "AgentResponseUpdate", @@ -312,6 +321,8 @@ def _serialize_value(value: Any, exclude_none: bool) -> Any: TChatResponse = TypeVar("TChatResponse", bound="ChatResponse") TToolMode = TypeVar("TToolMode", bound="ToolMode") TAgentRunResponse = TypeVar("TAgentRunResponse", bound="AgentResponse") +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None, covariant=True) +TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel) CreatedAtT = str # Use a datetimeoffset type? Or a more specific type like datetime.datetime? @@ -1911,7 +1922,7 @@ def _finalize_response(response: "ChatResponse | AgentResponse") -> None: _coalesce_text_content(msg.contents, "text_reasoning") -class ChatResponse(SerializationMixin): +class ChatResponse(SerializationMixin, Generic[TResponseModel]): """Represents the response to a chat request. Attributes: @@ -1974,7 +1985,7 @@ def __init__( created_at: CreatedAtT | None = None, finish_reason: FinishReason | None = None, usage_details: UsageDetails | None = None, - value: Any | None = None, + value: TResponseModel | None = None, response_format: type[BaseModel] | None = None, additional_properties: dict[str, Any] | None = None, raw_representation: Any | None = None, @@ -2009,7 +2020,7 @@ def __init__( created_at: CreatedAtT | None = None, finish_reason: FinishReason | None = None, usage_details: UsageDetails | None = None, - value: Any | None = None, + value: TResponseModel | None = None, response_format: type[BaseModel] | None = None, additional_properties: dict[str, Any] | None = None, raw_representation: Any | None = None, @@ -2044,7 +2055,7 @@ def __init__( created_at: CreatedAtT | None = None, finish_reason: FinishReason | dict[str, Any] | None = None, usage_details: UsageDetails | dict[str, Any] | None = None, - value: Any | None = None, + value: TResponseModel | None = None, response_format: type[BaseModel] | None = None, additional_properties: dict[str, Any] | None = None, raw_representation: Any | None = None, @@ -2101,13 +2112,31 @@ def __init__( self.created_at = created_at self.finish_reason = finish_reason self.usage_details = usage_details - self._value: Any | None = value + self._value: TResponseModel | None = value self._response_format: type[BaseModel] | None = response_format self._value_parsed: bool = value is not None self.additional_properties = additional_properties or {} self.additional_properties.update(kwargs or {}) self.raw_representation: Any | list[Any] | None = raw_representation + @overload + @classmethod + def from_chat_response_updates( + cls: type["ChatResponse[Any]"], + updates: Sequence["ChatResponseUpdate"], + *, + output_format_type: type[TResponseModelT], + ) -> 
"ChatResponse[TResponseModelT]": ... + + @overload + @classmethod + def from_chat_response_updates( + cls: type["ChatResponse[Any]"], + updates: Sequence["ChatResponseUpdate"], + *, + output_format_type: None = None, + ) -> "ChatResponse[Any]": ... + @classmethod def from_chat_response_updates( cls: type[TChatResponse], @@ -2146,12 +2175,30 @@ def from_chat_response_updates( msg.try_parse_value(output_format_type) return msg + @overload + @classmethod + async def from_chat_response_generator( + cls: type["ChatResponse[Any]"], + updates: AsyncIterable["ChatResponseUpdate"], + *, + output_format_type: type[TResponseModelT], + ) -> "ChatResponse[TResponseModelT]": ... + + @overload + @classmethod + async def from_chat_response_generator( + cls: type["ChatResponse[Any]"], + updates: AsyncIterable["ChatResponseUpdate"], + *, + output_format_type: None = None, + ) -> "ChatResponse[Any]": ... + @classmethod async def from_chat_response_generator( cls: type[TChatResponse], updates: AsyncIterable["ChatResponseUpdate"], *, - output_format_type: type[BaseModel] | Mapping[str, Any] | None = None, + output_format_type: type[BaseModel] | None = None, ) -> TChatResponse: """Joins multiple updates into a single ChatResponse. @@ -2187,7 +2234,7 @@ def text(self) -> str: return ("\n".join(message.text for message in self.messages if isinstance(message, ChatMessage))).strip() @property - def value(self) -> Any | None: + def value(self) -> TResponseModel | None: """Get the parsed structured output value. If a response_format was provided and parsing hasn't been attempted yet, @@ -2203,14 +2250,20 @@ def value(self) -> Any | None: and isinstance(self._response_format, type) and issubclass(self._response_format, BaseModel) ): - self._value = self._response_format.model_validate_json(self.text) + self._value = cast(TResponseModel, self._response_format.model_validate_json(self.text)) self._value_parsed = True return self._value def __str__(self) -> str: return self.text - def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | None: + @overload + def try_parse_value(self, output_format_type: type[TResponseModelT]) -> TResponseModelT | None: ... + + @overload + def try_parse_value(self, output_format_type: None = None) -> TResponseModel | None: ... + + def try_parse_value(self, output_format_type: type[BaseModel] | None = None) -> BaseModel | None: """Try to parse the text into a typed value. This is the safe alternative to accessing the value property directly. @@ -2238,7 +2291,7 @@ def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | No try: parsed_value = format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType] if use_cache: - self._value = parsed_value + self._value = cast(TResponseModel, parsed_value) self._value_parsed = True return parsed_value # type: ignore[return-value] except ValidationError as ex: @@ -2376,7 +2429,7 @@ def __str__(self) -> str: # region AgentResponse -class AgentResponse(SerializationMixin): +class AgentResponse(SerializationMixin, Generic[TResponseModel]): """Represents the response to an Agent run request. Provides one or more response messages and metadata about the response. 
@@ -2428,7 +2481,7 @@ def __init__( response_id: str | None = None, created_at: CreatedAtT | None = None, usage_details: UsageDetails | MutableMapping[str, Any] | None = None, - value: Any | None = None, + value: TResponseModel | None = None, response_format: type[BaseModel] | None = None, raw_representation: Any | None = None, additional_properties: dict[str, Any] | None = None, @@ -2469,7 +2522,7 @@ def __init__( self.response_id = response_id self.created_at = created_at self.usage_details = usage_details - self._value: Any | None = value + self._value: TResponseModel | None = value self._response_format: type[BaseModel] | None = response_format self._value_parsed: bool = value is not None self.additional_properties = additional_properties or {} @@ -2482,7 +2535,7 @@ def text(self) -> str: return "".join(msg.text for msg in self.messages) if self.messages else "" @property - def value(self) -> Any | None: + def value(self) -> TResponseModel | None: """Get the parsed structured output value. If a response_format was provided and parsing hasn't been attempted yet, @@ -2498,7 +2551,7 @@ def value(self) -> Any | None: and isinstance(self._response_format, type) and issubclass(self._response_format, BaseModel) ): - self._value = self._response_format.model_validate_json(self.text) + self._value = cast(TResponseModel, self._response_format.model_validate_json(self.text)) self._value_parsed = True return self._value @@ -2512,6 +2565,24 @@ def user_input_requests(self) -> list[Content]: if isinstance(content, Content) and content.user_input_request ] + @overload + @classmethod + def from_agent_run_response_updates( + cls: type["AgentResponse[Any]"], + updates: Sequence["AgentResponseUpdate"], + *, + output_format_type: type[TResponseModelT], + ) -> "AgentResponse[TResponseModelT]": ... + + @overload + @classmethod + def from_agent_run_response_updates( + cls: type["AgentResponse[Any]"], + updates: Sequence["AgentResponseUpdate"], + *, + output_format_type: None = None, + ) -> "AgentResponse[Any]": ... + @classmethod def from_agent_run_response_updates( cls: type[TAgentRunResponse], @@ -2535,6 +2606,24 @@ def from_agent_run_response_updates( msg.try_parse_value(output_format_type) return msg + @overload + @classmethod + async def from_agent_response_generator( + cls: type["AgentResponse[Any]"], + updates: AsyncIterable["AgentResponseUpdate"], + *, + output_format_type: type[TResponseModelT], + ) -> "AgentResponse[TResponseModelT]": ... + + @overload + @classmethod + async def from_agent_response_generator( + cls: type["AgentResponse[Any]"], + updates: AsyncIterable["AgentResponseUpdate"], + *, + output_format_type: None = None, + ) -> "AgentResponse[Any]": ... + @classmethod async def from_agent_response_generator( cls: type[TAgentRunResponse], @@ -2561,7 +2650,13 @@ async def from_agent_response_generator( def __str__(self) -> str: return self.text - def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | None: + @overload + def try_parse_value(self, output_format_type: type[TResponseModelT]) -> TResponseModelT | None: ... + + @overload + def try_parse_value(self, output_format_type: None = None) -> TResponseModel | None: ... + + def try_parse_value(self, output_format_type: type[BaseModel] | None = None) -> BaseModel | None: """Try to parse the text into a typed value. This is the safe alternative when you need to parse the response text into a typed value. 
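For reference, the lazy structured-output parsing that the generic ChatResponse and AgentResponse expose via value and try_parse_value works roughly as below. This is a self-contained sketch, not the framework classes: TypedTextResponse and Weather are placeholder names, and it assumes pydantic v2 plus typing_extensions for TypeVar defaults.

from typing import Generic, cast

from pydantic import BaseModel, ValidationError
from typing_extensions import TypeVar

TModel = TypeVar("TModel", bound=BaseModel | None, default=None)


class TypedTextResponse(Generic[TModel]):
    # Holds raw response text plus an optional response_format used to parse it lazily.
    def __init__(self, text: str, response_format: type[BaseModel] | None = None) -> None:
        self.text = text
        self._response_format = response_format
        self._value: TModel | None = None
        self._parsed = False

    @property
    def value(self) -> TModel | None:
        # Parse on first access and cache, mirroring the _value_parsed flag above.
        if not self._parsed and self._response_format is not None:
            self._value = cast(TModel, self._response_format.model_validate_json(self.text))
            self._parsed = True
        return self._value

    def try_parse_value(self, output_format_type: type[BaseModel] | None = None) -> BaseModel | None:
        # Safe variant: returns None instead of raising when the text does not match the schema.
        fmt = output_format_type or self._response_format
        if fmt is None:
            return None
        try:
            return fmt.model_validate_json(self.text)
        except ValidationError:
            return None


class Weather(BaseModel):
    city: str
    temperature_c: float


resp: TypedTextResponse[Weather] = TypedTextResponse(
    '{"city": "Seattle", "temperature_c": 11.5}', response_format=Weather
)
assert resp.value is not None and resp.value.city == "Seattle"

Accessing value raises a ValidationError on malformed text, while try_parse_value returns None, matching the contract the docstrings above describe.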
@@ -2589,7 +2684,7 @@ def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | No try: parsed_value = format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType] if use_cache: - self._value = parsed_value + self._value = cast(TResponseModel, parsed_value) self._value_parsed = True return parsed_value # type: ignore[return-value] except ValidationError as ex: @@ -2718,7 +2813,7 @@ class ToolMode(TypedDict, total=False): # region TypedDict-based Chat Options -class ChatOptions(TypedDict, total=False): +class _ChatOptionsBase(TypedDict, total=False): """Common request settings for AI services as a TypedDict. All fields are optional (total=False) to allow partial specification. @@ -2771,7 +2866,7 @@ class ChatOptions(TypedDict, total=False): allow_multiple_tool_calls: bool # Response configuration - response_format: type[BaseModel] | dict[str, Any] + response_format: type[BaseModel] | Mapping[str, Any] | None # Metadata metadata: dict[str, Any] @@ -2783,6 +2878,15 @@ class ChatOptions(TypedDict, total=False): instructions: str +if TYPE_CHECKING: + + class ChatOptions(_ChatOptionsBase, Generic[TResponseModel], total=False): + response_format: type[TResponseModel] | Mapping[str, Any] | None # type: ignore[misc] + +else: + ChatOptions = _ChatOptionsBase + + # region Chat Options Utility Functions diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index 6b82823be1..1543ed7db6 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -2,11 +2,12 @@ import json import logging +import sys import uuid from collections.abc import AsyncIterable from dataclasses import dataclass from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from agent_framework import ( AgentResponse, @@ -32,6 +33,11 @@ from ._message_utils import normalize_messages_input from ._typing_utils import is_type_compatible +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover + if TYPE_CHECKING: from ._workflow import Workflow diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 9beaf06a65..6a355fc92d 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -16,7 +16,7 @@ from ._conversation_state import encode_chat_messages from ._events import ( AgentRunEvent, - AgentRunUpdateEvent, # type: ignore[reportPrivateUsage] + AgentRunUpdateEvent, ) from ._executor import Executor, handler from ._message_utils import normalize_messages_input @@ -24,9 +24,9 @@ from ._workflow_context import WorkflowContext if sys.version_info >= (3, 12): - from typing import override + from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override + from typing_extensions import override # type: ignore # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py index 026933d777..e3cc4bc7d2 100644 --- 
a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py +++ b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py @@ -22,9 +22,9 @@ from ._workflow_context import WorkflowContext if sys.version_info >= (3, 12): - from typing import override + from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override + from typing_extensions import override # type: ignore # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_group_chat.py b/python/packages/core/agent_framework/_workflows/_group_chat.py index d75b805514..26562ff9b6 100644 --- a/python/packages/core/agent_framework/_workflows/_group_chat.py +++ b/python/packages/core/agent_framework/_workflows/_group_chat.py @@ -52,9 +52,9 @@ from ._workflow_context import WorkflowContext if sys.version_info >= (3, 12): - from typing import override + from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override + from typing_extensions import override # type: ignore # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 00aa36dd99..cce07087e2 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -55,9 +55,9 @@ from ._workflow_context import WorkflowContext if sys.version_info >= (3, 12): - from typing import override + from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override + from typing_extensions import override # type: ignore # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 052a59766f..2f73e636e5 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -39,15 +39,14 @@ from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self - if sys.version_info >= (3, 12): - from typing import override + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import Self # type: ignore # pragma: no cover else: - from typing_extensions import override + from typing_extensions import Self # type: ignore # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 62f3836617..ce9fff6617 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -2,11 +2,12 @@ import asyncio import logging +import sys import uuid from copy import copy from dataclasses import dataclass from enum import Enum -from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable +from typing import Any, Protocol, TypeVar, runtime_checkable from ._checkpoint import CheckpointStorage, WorkflowCheckpoint from ._checkpoint_encoding import 
decode_checkpoint_value, encode_checkpoint_value @@ -15,6 +16,11 @@ from ._shared_state import SharedState from ._typing_utils import is_instance_of +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover + logger = logging.getLogger(__name__) T = TypeVar("T") diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index 8cc31e2cc9..14cabc219b 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -8,9 +8,8 @@ from typing_extensions import deprecated -from agent_framework import AgentThread - from .._agents import AgentProtocol +from .._threads import AgentThread from ..observability import OtelAttr, capture_exception, create_workflow_span from ._agent_executor import AgentExecutor from ._checkpoint import CheckpointStorage @@ -34,9 +33,9 @@ from ._workflow import Workflow if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover + from typing import Self # type: ignore # pragma: no cover else: - from typing_extensions import Self # pragma: no cover + from typing_extensions import Self # type: ignore # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 69f24bcf2c..2b2426b3eb 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -26,9 +26,9 @@ from ._workflow_context import WorkflowContext if sys.version_info >= (3, 12): - from typing import override + from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override + from typing_extensions import override # type: ignore # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/azure/_assistants_client.py b/python/packages/core/agent_framework/azure/_assistants_client.py index a835310435..4f1d2190be 100644 --- a/python/packages/core/agent_framework/azure/_assistants_client.py +++ b/python/packages/core/agent_framework/azure/_assistants_client.py @@ -19,8 +19,10 @@ from typing import TypeVar # type: ignore # pragma: no cover else: from typing_extensions import TypeVar # type: ignore # pragma: no cover - -from typing import TypedDict +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover __all__ = ["AzureOpenAIAssistantsClient"] diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index b60054165f..a372d6f0cc 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -4,13 +4,13 @@ import logging import sys from collections.abc import Mapping -from typing import Any, Generic, TypedDict +from typing import Any, Generic from azure.core.credentials import TokenCredential from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice 
-from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from agent_framework import ( Annotation, @@ -36,12 +36,18 @@ if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover logger: logging.Logger = logging.getLogger(__name__) __all__ = ["AzureOpenAIChatClient", "AzureOpenAIChatOptions", "AzureUserSecurityContext"] +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None) + # region Azure OpenAI Chat Options TypedDict @@ -68,7 +74,7 @@ class AzureUserSecurityContext(TypedDict, total=False): """The original client's IP address.""" -class AzureOpenAIChatOptions(OpenAIChatOptions, total=False): +class AzureOpenAIChatOptions(OpenAIChatOptions[TResponseModel], Generic[TResponseModel], total=False): """Azure OpenAI-specific chat options dict. Extends OpenAIChatOptions with Azure-specific options including diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index e4f6989fa0..884640375b 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -2,34 +2,38 @@ import sys from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Generic, TypedDict +from typing import TYPE_CHECKING, Any, Generic from urllib.parse import urljoin from azure.core.credentials import TokenCredential from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI from pydantic import ValidationError -from agent_framework import use_chat_middleware, use_function_invocation -from agent_framework.exceptions import ServiceInitializationError -from agent_framework.observability import use_instrumentation -from agent_framework.openai._responses_client import OpenAIBaseResponsesClient - +from .._middleware import use_chat_middleware +from .._tools import use_function_invocation +from ..exceptions import ServiceInitializationError +from ..observability import use_instrumentation +from ..openai._responses_client import OpenAIBaseResponsesClient from ._shared import ( AzureOpenAIConfigMixin, AzureOpenAISettings, ) -if TYPE_CHECKING: - from agent_framework.openai._responses_client import OpenAIResponsesOptions - -if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover -else: - from typing_extensions import override # type: ignore[import] # pragma: no cover if sys.version_info >= (3, 13): from typing import TypeVar # type: ignore # pragma: no cover else: from typing_extensions import TypeVar # type: ignore # pragma: no cover +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover + +if TYPE_CHECKING: + from ..openai._responses_client import OpenAIResponsesOptions __all__ = ["AzureOpenAIResponsesClient"] diff --git 
a/python/packages/core/agent_framework/openai/_assistant_provider.py b/python/packages/core/agent_framework/openai/_assistant_provider.py index 73acd2d05e..b35b525bf5 100644 --- a/python/packages/core/agent_framework/openai/_assistant_provider.py +++ b/python/packages/core/agent_framework/openai/_assistant_provider.py @@ -2,7 +2,7 @@ import sys from collections.abc import Awaitable, Callable, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast +from typing import TYPE_CHECKING, Any, Generic, cast from openai import AsyncOpenAI from openai.types.beta.assistant import Assistant @@ -21,10 +21,13 @@ from ._assistants_client import OpenAIAssistantsOptions if sys.version_info >= (3, 13): - from typing import Self, TypeVar # pragma: no cover + from typing import TypeVar # type:ignore # pragma: no cover else: - from typing_extensions import Self, TypeVar # pragma: no cover - + from typing_extensions import TypeVar # type:ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import Self, TypedDict # type:ignore # pragma: no cover +else: + from typing_extensions import Self, TypedDict # type:ignore # pragma: no cover __all__ = ["OpenAIAssistantProvider"] diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index afb98f1088..22852bea53 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -9,12 +9,8 @@ Mapping, MutableMapping, MutableSequence, - Sequence, ) -from typing import TYPE_CHECKING, Any, Generic, Literal, TypedDict, cast - -if TYPE_CHECKING: - from .._agents import ChatAgent +from typing import Any, Generic, Literal, cast from openai import AsyncOpenAI from openai.types.beta.threads import ( @@ -29,17 +25,14 @@ from openai.types.beta.threads.run_create_params import AdditionalMessage from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput from openai.types.beta.threads.runs import RunStep -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from .._clients import BaseChatClient -from .._memory import ContextProvider -from .._middleware import Middleware, use_chat_middleware -from .._threads import ChatMessageStoreProtocol +from .._middleware import use_chat_middleware from .._tools import ( FunctionTool, HostedCodeInterpreterTool, HostedFileSearchTool, - ToolProtocol, use_function_invocation, ) from .._types import ( @@ -57,19 +50,19 @@ from ._shared import OpenAIConfigMixin, OpenAISettings if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import TypeVar + from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # type: ignore # pragma: no cover if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover + from typing import Self, TypedDict # type: ignore # pragma: no cover else: - from typing_extensions import Self # pragma: no cover + from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover __all__ = [ @@ -81,6 +74,8 @@ # region OpenAI Assistants Options TypedDict +TResponseModel = TypeVar("TResponseModel", 
bound=BaseModel | None, default=None) + class VectorStoreToolResource(TypedDict, total=False): """Vector store configuration for file search tool resources.""" @@ -109,7 +104,7 @@ class AssistantToolResources(TypedDict, total=False): """Resources for file search tool, including vector store IDs.""" -class OpenAIAssistantsOptions(ChatOptions, total=False): +class OpenAIAssistantsOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False): """OpenAI Assistants API-specific options dict. Extends base ChatOptions with Assistants API-specific parameters @@ -765,59 +760,3 @@ def _update_agent_name_and_description(self, agent_name: str | None, description self.assistant_name = agent_name if description and not self.assistant_description: self.assistant_description = description - - @override - def as_agent( - self, - *, - id: str | None = None, - name: str | None = None, - description: str | None = None, - instructions: str | None = None, - tools: ToolProtocol - | Callable[..., Any] - | MutableMapping[str, Any] - | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] - | None = None, - default_options: TOpenAIAssistantsOptions | None = None, - chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, - context_provider: ContextProvider | None = None, - middleware: Sequence[Middleware] | None = None, - **kwargs: Any, - ) -> "ChatAgent[TOpenAIAssistantsOptions]": - """Convert this chat client to a ChatAgent. - - This method creates a ChatAgent instance with this client pre-configured. - It does NOT create an assistant on the OpenAI service - the actual assistant - will be created on the server during the first invocation (run). - - For creating and managing persistent assistants on the server, use - :class:`~agent_framework.openai.OpenAIAssistantProvider` instead. - - Keyword Args: - id: The unique identifier for the agent. Will be created automatically if not provided. - name: The name of the agent. - description: A brief description of the agent's purpose. - instructions: Optional instructions for the agent. - tools: The tools to use for the request. - default_options: A TypedDict containing chat options. - chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol. - context_provider: Context providers to include during agent invocation. - middleware: List of middleware to intercept agent and function invocations. - kwargs: Any additional keyword arguments. - - Returns: - A ChatAgent instance configured with this chat client. 
- """ - return super().as_agent( - id=id, - name=name, - description=description, - instructions=instructions, - tools=tools, - default_options=default_options, - chat_message_store_factory=chat_message_store_factory, - context_provider=context_provider, - middleware=middleware, - **kwargs, - ) diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index a1bc1f846a..e70b4790f6 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence from datetime import datetime, timezone from itertools import chain -from typing import Any, Generic, Literal, TypedDict +from typing import Any, Generic, Literal from openai import AsyncOpenAI, BadRequestError from openai.lib._parsing._completions import type_to_response_format_param @@ -14,7 +14,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from .._clients import BaseChatClient from .._logging import get_logger @@ -41,19 +41,24 @@ from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import TypeVar - + from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover __all__ = ["OpenAIChatClient", "OpenAIChatOptions"] logger = get_logger("agent_framework.openai") +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None) + # region OpenAI Chat Options TypedDict @@ -72,7 +77,7 @@ class Prediction(TypedDict, total=False): content: str | list[PredictionTextContent] -class OpenAIChatOptions(ChatOptions, total=False): +class OpenAIChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False): """OpenAI-specific chat options dict. Extends ChatOptions with options specific to OpenAI's Chat Completions API. 
diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 9c12357e0a..9a3436e5ce 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -12,7 +12,7 @@ ) from datetime import datetime, timezone from itertools import chain -from typing import Any, Generic, Literal, TypedDict, cast +from typing import Any, Generic, Literal, cast from openai import AsyncOpenAI, BadRequestError from openai.types.responses.file_search_tool_param import FileSearchToolParam @@ -79,9 +79,14 @@ from typing import override # type: ignore # pragma: no cover else: from typing_extensions import override # type: ignore[import] # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover logger = get_logger("agent_framework.openai") + __all__ = ["OpenAIResponsesClient", "OpenAIResponsesOptions"] @@ -108,7 +113,10 @@ class StreamOptions(TypedDict, total=False): """Whether to include usage statistics in stream events.""" -class OpenAIResponsesOptions(ChatOptions, total=False): +TResponseFormat = TypeVar("TResponseFormat", bound=BaseModel | None, default=None) + + +class OpenAIResponsesOptions(ChatOptions[TResponseFormat], Generic[TResponseFormat], total=False): """OpenAI Responses API-specific chat options. Extends ChatOptions with options specific to OpenAI's Responses API. diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index 0f380ff06f..a60018c7a4 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. 
-from typing import Annotated, Any, Literal +from typing import Annotated, Any, Literal, get_args, get_origin from unittest.mock import Mock import pytest @@ -1493,8 +1493,6 @@ def tool_with_kwargs(x: int, **kwargs: Any) -> str: def test_parse_annotation_with_literal_type(): """Test that _parse_annotation returns Literal types unchanged (issue #2891).""" - from typing import get_args, get_origin - # Literal with string values literal_annotation = Literal["Data", "Security", "Network"] result = _parse_annotation(literal_annotation) @@ -1505,7 +1503,6 @@ def test_parse_annotation_with_literal_type(): def test_parse_annotation_with_literal_int_type(): """Test that _parse_annotation returns Literal int types unchanged.""" - from typing import get_args, get_origin literal_annotation = Literal[1, 2, 3] result = _parse_annotation(literal_annotation) @@ -1516,7 +1513,6 @@ def test_parse_annotation_with_literal_int_type(): def test_parse_annotation_with_literal_bool_type(): """Test that _parse_annotation returns Literal bool types unchanged.""" - from typing import get_args, get_origin literal_annotation = Literal[True, False] result = _parse_annotation(literal_annotation) @@ -1535,7 +1531,6 @@ def test_parse_annotation_with_simple_types(): def test_parse_annotation_with_annotated_and_literal(): """Test that Annotated[Literal[...], description] works correctly.""" - from typing import get_args, get_origin # When Literal is inside Annotated, it should still be preserved annotated_literal = Annotated[Literal["A", "B", "C"], "The category"] diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index 90cb912b3d..0d55822d93 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -3,10 +3,10 @@ import base64 from collections.abc import AsyncIterable from datetime import datetime, timezone -from typing import Any +from typing import Any, Literal import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field, ValidationError from pytest import fixture, mark, raises from agent_framework import ( @@ -665,9 +665,6 @@ def test_chat_response_with_format_init(): def test_chat_response_value_raises_on_invalid_schema(): """Test that value property raises ValidationError with field constraint details.""" - from typing import Literal - - from pydantic import Field, ValidationError class StrictSchema(BaseModel): id: Literal[5] @@ -689,9 +686,6 @@ class StrictSchema(BaseModel): def test_chat_response_try_parse_value_returns_none_on_invalid(): """Test that try_parse_value returns None on validation failure with Field constraints.""" - from typing import Literal - - from pydantic import Field class StrictSchema(BaseModel): id: Literal[5] @@ -707,7 +701,6 @@ class StrictSchema(BaseModel): def test_chat_response_try_parse_value_returns_value_on_success(): """Test that try_parse_value returns parsed value when all constraints pass.""" - from pydantic import Field class MySchema(BaseModel): name: str = Field(min_length=3) @@ -724,9 +717,6 @@ class MySchema(BaseModel): def test_agent_response_value_raises_on_invalid_schema(): """Test that AgentResponse.value property raises ValidationError with field constraint details.""" - from typing import Literal - - from pydantic import Field, ValidationError class StrictSchema(BaseModel): id: Literal[5] @@ -748,9 +738,6 @@ class StrictSchema(BaseModel): def test_agent_response_try_parse_value_returns_none_on_invalid(): """Test that 
AgentResponse.try_parse_value returns None on Field constraint failure.""" - from typing import Literal - - from pydantic import Field class StrictSchema(BaseModel): id: Literal[5] @@ -766,7 +753,6 @@ class StrictSchema(BaseModel): def test_agent_response_try_parse_value_returns_value_on_success(): """Test that AgentResponse.try_parse_value returns parsed value when all constraints pass.""" - from pydantic import Field class MySchema(BaseModel): name: str = Field(min_length=3) diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index a812f6dae6..1988722a5d 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import pytest +from typing_extensions import Never from agent_framework import ( ChatMessage, @@ -187,7 +188,6 @@ async def handle(self, text: str, ctx: WorkflowContext) -> None: async def test_executor_completed_event_includes_yielded_outputs(): """Test that ExecutorCompletedEvent.data includes yielded outputs.""" - from typing_extensions import Never from agent_framework import WorkflowOutputEvent @@ -318,7 +318,6 @@ async def handle_number(self, num: int, ctx: WorkflowContext[bool]) -> None: def test_executor_workflow_output_types_property(): """Test that the workflow_output_types property correctly identifies workflow output types.""" - from typing_extensions import Never # Test executor with no workflow output types class NoWorkflowOutputExecutor(Executor): diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 24c830968f..29035dbc6e 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -42,9 +42,9 @@ from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage if sys.version_info >= (3, 12): - from typing import override + from typing import override # type: ignore # pragma: no cover else: - from typing_extensions import override + from typing_extensions import override # type: ignore # pragma: no cover def test_magentic_context_reset_behavior(): diff --git a/python/packages/declarative/agent_framework_declarative/_loader.py b/python/packages/declarative/agent_framework_declarative/_loader.py index c7c06d2f0a..7dbd34f12d 100644 --- a/python/packages/declarative/agent_framework_declarative/_loader.py +++ b/python/packages/declarative/agent_framework_declarative/_loader.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. 
+import sys from collections.abc import Callable, Mapping from pathlib import Path -from typing import Any, Literal, TypedDict, cast +from typing import Any, Literal, cast import yaml from agent_framework import ( @@ -42,6 +43,11 @@ agent_schema_dispatch, ) +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover + class ProviderTypeMapping(TypedDict, total=True): package: str diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index 0d881389b2..309a71a4b7 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -24,10 +24,11 @@ """ import logging +import sys from collections.abc import Mapping from dataclasses import dataclass from decimal import Decimal as _Decimal -from typing import Any, Literal, TypedDict, cast +from typing import Any, Literal, cast from agent_framework._workflows import ( Executor, @@ -36,6 +37,12 @@ ) from powerfx import Engine +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover + + logger = logging.getLogger(__name__) diff --git a/python/packages/devui/agent_framework_devui/_utils.py b/python/packages/devui/agent_framework_devui/_utils.py index 24cdc9c073..b715263075 100644 --- a/python/packages/devui/agent_framework_devui/_utils.py +++ b/python/packages/devui/agent_framework_devui/_utils.py @@ -7,7 +7,7 @@ import logging from dataclasses import fields, is_dataclass from types import UnionType -from typing import Any, Union, get_args, get_origin +from typing import Any, Union, get_args, get_origin, get_type_hints from agent_framework import ChatMessage @@ -270,8 +270,6 @@ def generate_schema_from_serialization_mixin(cls: type[Any]) -> dict[str, Any]: # Get type hints try: - from typing import get_type_hints - type_hints = get_type_hints(cls) except Exception: type_hints = {} @@ -348,8 +346,6 @@ async def handler(self, original_request: RequestType, response: ResponseType, c The response type class, or None if not found """ try: - from typing import get_type_hints - # Introspect handler methods for @response_handler pattern for attr_name in dir(executor): if attr_name.startswith("_"): diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py index 5cd38ea251..380bd64f7b 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py +++ b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. 
import sys -from typing import Any, ClassVar, Generic, TypedDict +from typing import Any, ClassVar, Generic from agent_framework import ChatOptions, use_chat_middleware, use_function_invocation from agent_framework._pydantic import AFBaseSettings @@ -11,12 +11,16 @@ from foundry_local import FoundryLocalManager from foundry_local.models import DeviceType from openai import AsyncOpenAI +from pydantic import BaseModel if sys.version_info >= (3, 13): from typing import TypeVar # type: ignore # pragma: no cover else: from typing_extensions import TypeVar # type: ignore # pragma: no cover - +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover __all__ = [ "FoundryLocalChatOptions", @@ -24,11 +28,13 @@ "FoundryLocalSettings", ] +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None) + # region Foundry Local Chat Options TypedDict -class FoundryLocalChatOptions(ChatOptions, total=False): +class FoundryLocalChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False): """Azure Foundry Local (local model deployment) chat options dict. Extends base ChatOptions for local model inference via Foundry Local. diff --git a/python/packages/mem0/agent_framework_mem0/_provider.py b/python/packages/mem0/agent_framework_mem0/_provider.py index e34c2cf435..3ff04df2a4 100644 --- a/python/packages/mem0/agent_framework_mem0/_provider.py +++ b/python/packages/mem0/agent_framework_mem0/_provider.py @@ -3,7 +3,7 @@ import sys from collections.abc import MutableSequence, Sequence from contextlib import AbstractAsyncContextManager -from typing import Any, TypedDict +from typing import Any from agent_framework import ChatMessage, Context, ContextProvider from agent_framework.exceptions import ServiceInitializationError @@ -15,9 +15,9 @@ from typing_extensions import override # type: ignore[import] # pragma: no cover if sys.version_info >= (3, 11): - from typing import NotRequired, Self # pragma: no cover + from typing import NotRequired, Self, TypedDict # pragma: no cover else: - from typing_extensions import NotRequired, Self # pragma: no cover + from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover # Type aliases for Mem0 search response formats (v1.1 and v2; v1 is deprecated, but matches the type definition for v2) diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index 058aeecb4b..ead729b8e2 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -11,7 +11,7 @@ Sequence, ) from itertools import chain -from typing import Any, ClassVar, Generic, TypedDict +from typing import Any, ClassVar, Generic from agent_framework import ( BaseChatClient, @@ -40,26 +40,32 @@ # Rename imported types to avoid naming conflicts with Agent Framework types from ollama._types import ChatResponse as OllamaChatResponse from ollama._types import Message as OllamaMessage -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # type: ignore # pragma: no cover else: - from typing_extensions import TypeVar + from typing_extensions import TypeVar # type: ignore # pragma: no cover if sys.version_info >= (3, 12): from typing import override # type: ignore # 
pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # type: ignore # pragma: no cover +if sys.version_info >= (3, 11): + from typing import TypedDict # type: ignore # pragma: no cover +else: + from typing_extensions import TypedDict # type: ignore # pragma: no cover __all__ = ["OllamaChatClient", "OllamaChatOptions"] +TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None) + # region Ollama Chat Options TypedDict -class OllamaChatOptions(ChatOptions, total=False): +class OllamaChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False): """Ollama-specific chat options dict. Extends base ChatOptions with Ollama-specific parameters. diff --git a/python/samples/getting_started/agents/openai/openai_responses_client_with_structured_output.py b/python/samples/getting_started/agents/openai/openai_responses_client_with_structured_output.py index c33951d5d3..c893f271b1 100644 --- a/python/samples/getting_started/agents/openai/openai_responses_client_with_structured_output.py +++ b/python/samples/getting_started/agents/openai/openai_responses_client_with_structured_output.py @@ -37,9 +37,9 @@ async def non_streaming_example() -> None: # Get structured response from the agent using response_format parameter result = await agent.run(query, options={"response_format": OutputStruct}) - # Access the structured output using try_parse_value for safe parsing - if structured_data := result.try_parse_value(OutputStruct): - print("Structured Output Agent (from result.try_parse_value):") + # Access the structured output using the parsed value + if structured_data := result.value: + print("Structured Output Agent:") print(f"City: {structured_data.city}") print(f"Description: {structured_data.description}") else: @@ -66,8 +66,8 @@ async def streaming_example() -> None: output_format_type=OutputStruct, ) - # Access the structured output using try_parse_value for safe parsing - if structured_data := result.try_parse_value(OutputStruct): + # Access the structured output using the parsed value + if structured_data := result.value: print("Structured Output (from streaming with AgentResponse.from_agent_response_generator):") print(f"City: {structured_data.city}") print(f"Description: {structured_data.description}")
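
A closing note on the sample change above: per the tests touched in this diff, the value property raises pydantic's ValidationError when the payload violates field constraints, whereas try_parse_value returns None. A standalone sketch of that contrast using plain pydantic; StrictSchema and the payloads are illustrative and not part of the framework.

from pydantic import BaseModel, Field, ValidationError


class StrictSchema(BaseModel):
    # Illustrative schema with a minimum-length constraint.
    name: str = Field(min_length=3)


def parse_or_none(payload: str) -> StrictSchema | None:
    # Mirrors the try_parse_value behaviour exercised by the tests:
    # swallow the ValidationError and return None instead of raising.
    try:
        return StrictSchema.model_validate_json(payload)
    except ValidationError:
        return None


assert parse_or_none('{"name": "ab"}') is None       # constraint fails -> None
assert parse_or_none('{"name": "abc"}') is not None  # constraint passes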