diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index b2b0e6eb..a8e08540 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -53,6 +53,7 @@ lifecycles linting Llm lstrips +mikeas mockurl notif oauthoidc @@ -67,6 +68,7 @@ pyi pypistats pyupgrade pyversions +redef respx resub RUF @@ -76,5 +78,6 @@ sse tagwords taskupdate testuuid +Tful typeerror vulnz diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index f93d1d91..f74718e0 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -57,6 +57,6 @@ jobs: - name: Install dependencies run: uv sync --dev --extra sql --extra encryption --extra grpc --extra telemetry - name: Run tests and check coverage - run: uv run pytest --cov=a2a --cov-report=xml --cov-fail-under=89 + run: uv run pytest --cov=a2a --cov-report term --cov-fail-under=89 - name: Show coverage summary in log run: uv run coverage report diff --git a/.vscode/launch.json b/.vscode/launch.json index 6adb30d5..5c19f481 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -47,7 +47,8 @@ "-s" ], "console": "integratedTerminal", - "justMyCode": true + "justMyCode": true, + "python": "${workspaceFolder}/.venv/bin/python", } ] -} \ No newline at end of file +} diff --git a/Gemini.md b/Gemini.md new file mode 100644 index 00000000..d4367c37 --- /dev/null +++ b/Gemini.md @@ -0,0 +1,27 @@ +**A2A specification:** https://a2a-protocol.org/latest/specification/ + +## Project frameworks +- uv as package manager + +## How to run all tests +1. If dependencies are not installed install them using following command + ``` + uv sync --all-extras + ``` + +2. Run tests + ``` + uv run pytest + ``` + +## Other instructions +1. Whenever writing python code, write types as well. +2. After making the changes run ruff to check and fix the formatting issues + ``` + uv run ruff check --fix + ``` +3. Run mypy type checkers to check for type errors + ``` + uv run mypy + ``` +4. Run the unit tests to make sure that none of the unit tests are broken. diff --git a/error_handlers.py b/error_handlers.py new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index f6da78a4..0bea4e38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,14 @@ authors = [{ name = "Google LLC", email = "googleapis-packages@google.com" }] requires-python = ">=3.10" keywords = ["A2A", "A2A SDK", "A2A Protocol", "Agent2Agent", "Agent 2 Agent"] dependencies = [ - "fastapi>=0.115.2", + "fastapi>=0.116.1", "httpx>=0.28.1", "httpx-sse>=0.4.0", "pydantic>=2.11.3", "sse-starlette", - "starlette" + "starlette", + "protobuf==5.29.5", + "google-api-core>=1.26.0", ] classifiers = [ @@ -35,7 +37,7 @@ mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] sql = ["sqlalchemy[asyncio,postgresql-asyncpg,aiomysql,aiosqlite]>=2.0.0"] encryption = ["cryptography>=43.0.0"] -grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0", "protobuf==5.29.5", "google-api-core>=1.26.0"] +grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"] telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"] [project.urls] @@ -90,6 +92,7 @@ dev = [ "pyupgrade", "autoflake", "no_implicit_optional", + "trio", ] [[tool.uv.index]] diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 33200ad1..726a089d 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -7,7 +7,9 @@ CredentialService, InMemoryContextCredentialStore, ) -from a2a.client.client import A2ACardResolver, A2AClient +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer +from a2a.client.client_factory import ClientFactory, minimal_agent_card from a2a.client.errors import ( A2AClientError, A2AClientHTTPError, @@ -15,13 +17,14 @@ A2AClientTimeoutError, ) from a2a.client.helpers import create_text_message_object +from a2a.client.legacy import A2AClient from a2a.client.middleware import ClientCallContext, ClientCallInterceptor logger = logging.getLogger(__name__) try: - from a2a.client.grpc_client import A2AGrpcClient # type: ignore + from a2a.client.legacy_grpc import A2AGrpcClient # type: ignore except ImportError as e: _original_error = e logger.debug( @@ -48,9 +51,15 @@ def __init__(self, *args, **kwargs): 'A2AClientTimeoutError', 'A2AGrpcClient', 'AuthInterceptor', + 'Client', 'ClientCallContext', 'ClientCallInterceptor', + 'ClientConfig', + 'ClientEvent', + 'ClientFactory', + 'Consumer', 'CredentialService', 'InMemoryContextCredentialStore', 'create_text_message_object', + 'minimal_agent_card', ] diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py new file mode 100644 index 00000000..f4a8d03d --- /dev/null +++ b/src/a2a/client/base_client.py @@ -0,0 +1,241 @@ +from collections.abc import AsyncIterator + +from a2a.client.client import ( + Client, + ClientCallContext, + ClientConfig, + ClientEvent, + Consumer, +) +from a2a.client.client_task_manager import ClientTaskManager +from a2a.client.errors import A2AClientInvalidStateError +from a2a.client.middleware import ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendConfiguration, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) + + +class BaseClient(Client): + """Base implementation of the A2A client, containing transport-independent logic.""" + + def __init__( + self, + card: AgentCard, + config: ClientConfig, + transport: ClientTransport, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], + ): + super().__init__(consumers, middleware) + self._card = card + self._config = config + self._transport = transport + + async def send_message( + self, + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent | Message]: + """Sends a message to the agent. + + This method handles both streaming and non-streaming (polling) interactions + based on the client configuration and agent capabilities. It will yield + events as they are received from the agent. + + Args: + request: The message to send to the agent. + context: The client call context. + + Yields: + An async iterator of `ClientEvent` or a final `Message` response. + """ + config = MessageSendConfiguration( + accepted_output_modes=self._config.accepted_output_modes, + blocking=not self._config.polling, + push_notification_config=( + self._config.push_notification_configs[0] + if self._config.push_notification_configs + else None + ), + ) + params = MessageSendParams(message=request, configuration=config) + + if not self._config.streaming or not self._card.capabilities.streaming: + response = await self._transport.send_message( + params, context=context + ) + result = ( + (response, None) if isinstance(response, Task) else response + ) + await self.consume(result, self._card) + yield result + return + + tracker = ClientTaskManager() + stream = self._transport.send_message_streaming(params, context=context) + + first_event = await anext(stream) + # The response from a server may be either exactly one Message or a + # series of Task updates. Separate out the first message for special + # case handling, which allows us to simplify further stream processing. + if isinstance(first_event, Message): + await self.consume(first_event, self._card) + yield first_event + return + + yield await self._process_response(tracker, first_event) + + async for event in stream: + yield await self._process_response(tracker, event) + + async def _process_response( + self, + tracker: ClientTaskManager, + event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent, + ) -> ClientEvent: + if isinstance(event, Message): + raise A2AClientInvalidStateError( + 'received a streamed Message from server after first response; this is not supported' + ) + await tracker.process(event) + task = tracker.get_task_or_raise() + update = None if isinstance(event, Task) else event + client_event = (task, update) + await self.consume(client_event, self._card) + return client_event + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task. + + Args: + request: The `TaskQueryParams` object specifying the task ID. + context: The client call context. + + Returns: + A `Task` object representing the current state of the task. + """ + return await self._transport.get_task(request, context=context) + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + context: The client call context. + + Returns: + A `Task` object containing the updated task status. + """ + return await self._transport.cancel_task(request, context=context) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `TaskPushNotificationConfig` object with the new configuration. + context: The client call context. + + Returns: + The created or updated `TaskPushNotificationConfig` object. + """ + return await self._transport.set_task_callback(request, context=context) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigParams` object specifying the task. + context: The client call context. + + Returns: + A `TaskPushNotificationConfig` object containing the configuration. + """ + return await self._transport.get_task_callback(request, context=context) + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent]: + """Resubscribes to a task's event stream. + + This is only available if both the client and server support streaming. + + Args: + request: Parameters to identify the task to resubscribe to. + context: The client call context. + + Yields: + An async iterator of `ClientEvent` objects. + + Raises: + NotImplementedError: If streaming is not supported by the client or server. + """ + if not self._config.streaming or not self._card.capabilities.streaming: + raise NotImplementedError( + 'client and/or server do not support resubscription.' + ) + + tracker = ClientTaskManager() + # Note: resubscribe can only be called on an existing task. As such, + # we should never see Message updates, despite the typing of the service + # definition indicating it may be possible. + async for event in self._transport.resubscribe( + request, context=context + ): + yield await self._process_response(tracker, event) + + async def get_card( + self, *, context: ClientCallContext | None = None + ) -> AgentCard: + """Retrieves the agent's card. + + This will fetch the authenticated card if necessary and update the + client's internal state with the new card. + + Args: + context: The client call context. + + Returns: + The `AgentCard` for the agent. + """ + card = await self._transport.get_card(context=context) + self._card = card + return card + + async def close(self) -> None: + """Closes the underlying transport.""" + await self._transport.close() diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py new file mode 100644 index 00000000..9df55152 --- /dev/null +++ b/src/a2a/client/card_resolver.py @@ -0,0 +1,108 @@ +import json +import logging + +from typing import Any + +import httpx + +from pydantic import ValidationError + +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, +) +from a2a.types import ( + AgentCard, +) +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH + + +logger = logging.getLogger(__name__) + + +class A2ACardResolver: + """Agent Card resolver.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + base_url: str, + agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, + ) -> None: + """Initializes the A2ACardResolver. + + Args: + httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). + base_url: The base URL of the agent's host. + agent_card_path: The path to the agent card endpoint, relative to the base URL. + """ + self.base_url = base_url.rstrip('/') + self.agent_card_path = agent_card_path.lstrip('/') + self.httpx_client = httpx_client + + async def get_agent_card( + self, + relative_card_path: str | None = None, + http_kwargs: dict[str, Any] | None = None, + ) -> AgentCard: + """Fetches an agent card from a specified path relative to the base_url. + + If relative_card_path is None, it defaults to the resolver's configured + agent_card_path (for the public agent card). + + Args: + relative_card_path: Optional path to the agent card endpoint, + relative to the base URL. If None, uses the default public + agent card path. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.get request. + + Returns: + An `AgentCard` object representing the agent's capabilities. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON + or validated against the AgentCard schema. + """ + if relative_card_path is None: + # Use the default public agent card path configured during initialization + path_segment = self.agent_card_path + else: + path_segment = relative_card_path.lstrip('/') + + target_url = f'{self.base_url}/{path_segment}' + + try: + response = await self.httpx_client.get( + target_url, + **(http_kwargs or {}), + ) + response.raise_for_status() + agent_card_data = response.json() + logger.info( + 'Successfully fetched agent card data from %s: %s', + target_url, + agent_card_data, + ) + agent_card = AgentCard.model_validate(agent_card_data) + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError( + e.response.status_code, + f'Failed to fetch agent card from {target_url}: {e}', + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError( + f'Failed to parse JSON for agent card from {target_url}: {e}' + ) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, + f'Network communication error fetching agent card from {target_url}: {e}', + ) from e + except ValidationError as e: # Pydantic validation error + raise A2AClientJSONError( + f'Failed to validate agent card structure from {target_url}: {e.json()}' + ) from e + + return agent_card diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 66dfe0a4..7cc10423 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -1,500 +1,197 @@ -import json +import dataclasses import logging -from collections.abc import AsyncGenerator +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Callable, Coroutine from typing import Any -from uuid import uuid4 import httpx -from httpx_sse import SSEError, aconnect_sse -from pydantic import ValidationError - -from a2a.client.errors import ( - A2AClientHTTPError, - A2AClientJSONError, - A2AClientTimeoutError, -) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.optionals import Channel from a2a.types import ( AgentCard, - CancelTaskRequest, - CancelTaskResponse, - GetTaskPushNotificationConfigRequest, - GetTaskPushNotificationConfigResponse, - GetTaskRequest, - GetTaskResponse, - SendMessageRequest, - SendMessageResponse, - SendStreamingMessageRequest, - SendStreamingMessageResponse, - SetTaskPushNotificationConfigRequest, - SetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigParams, + Message, + PushNotificationConfig, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, + TransportProtocol, ) -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, -) -from a2a.utils.telemetry import SpanKind, trace_class logger = logging.getLogger(__name__) -class A2ACardResolver: - """Agent Card resolver.""" +@dataclasses.dataclass +class ClientConfig: + """Configuration class for the A2AClient Factory.""" - def __init__( - self, - httpx_client: httpx.AsyncClient, - base_url: str, - agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, - ) -> None: - """Initializes the A2ACardResolver. + streaming: bool = True + """Whether client supports streaming""" - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - base_url: The base URL of the agent's host. - agent_card_path: The path to the agent card endpoint, relative to the base URL. - """ - self.base_url = base_url.rstrip('/') - self.agent_card_path = agent_card_path.lstrip('/') - self.httpx_client = httpx_client + polling: bool = False + """Whether client prefers to poll for updates from message:send. It is + the callers job to check if the response is completed and if not run a + polling loop.""" - async def get_agent_card( - self, - relative_card_path: str | None = None, - http_kwargs: dict[str, Any] | None = None, - ) -> AgentCard: - """Fetches an agent card from a specified path relative to the base_url. + httpx_client: httpx.AsyncClient | None = None + """Http client to use to connect to agent.""" - If relative_card_path is None, it defaults to the resolver's configured - agent_card_path (for the public agent card). + grpc_channel_factory: Callable[[str], Channel] | None = None + """Generates a grpc connection channel for a given url.""" - Args: - relative_card_path: Optional path to the agent card endpoint, - relative to the base URL. If None, uses the default public - agent card path. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.get request. - - Returns: - An `AgentCard` object representing the agent's capabilities. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON - or validated against the AgentCard schema. - """ - if relative_card_path is None: - # Use the default public agent card path configured during initialization - path_segment = self.agent_card_path - else: - path_segment = relative_card_path.lstrip('/') - - target_url = f'{self.base_url}/{path_segment}' - - try: - response = await self.httpx_client.get( - target_url, - **(http_kwargs or {}), - ) - response.raise_for_status() - agent_card_data = response.json() - logger.info( - 'Successfully fetched agent card data from %s: %s', - target_url, - agent_card_data, - ) - agent_card = AgentCard.model_validate(agent_card_data) - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError( - e.response.status_code, - f'Failed to fetch agent card from {target_url}: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError( - f'Failed to parse JSON for agent card from {target_url}: {e}' - ) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, - f'Network communication error fetching agent card from {target_url}: {e}', - ) from e - except ValidationError as e: # Pydantic validation error - raise A2AClientJSONError( - f'Failed to validate agent card structure from {target_url}: {e.json()}' - ) from e - - return agent_card - - -@trace_class(kind=SpanKind.CLIENT) -class A2AClient: - """A2A Client for interacting with an A2A agent.""" + supported_transports: list[TransportProtocol | str] = dataclasses.field( + default_factory=list + ) + """Ordered list of transports for connecting to agent + (in order of preference). Empty implies JSONRPC only. - def __init__( - self, - httpx_client: httpx.AsyncClient, - agent_card: AgentCard | None = None, - url: str | None = None, - interceptors: list[ClientCallInterceptor] | None = None, - ): - """Initializes the A2AClient. + This is a string type to allow custom + transports to exist in closed ecosystems. + """ - Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. + use_client_preference: bool = False + """Whether to use client transport preferences over server preferences. + Recommended to use server preferences in most situations.""" - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. - url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. - interceptors: An optional list of client call interceptors to apply to requests. + accepted_output_modes: list[str] = dataclasses.field(default_factory=list) + """The set of accepted output modes for the client.""" - Raises: - ValueError: If neither `agent_card` nor `url` is provided. - """ - if agent_card: - self.url = agent_card.url - elif url: - self.url = url - else: - raise ValueError('Must provide either agent_card or url') - - self.httpx_client = httpx_client - self.agent_card = agent_card - self.interceptors = interceptors or [] - - async def _apply_interceptors( - self, - method_name: str, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Applies all registered interceptors to the request.""" - final_http_kwargs = http_kwargs or {} - final_request_payload = request_payload - - for interceptor in self.interceptors: - ( - final_request_payload, - final_http_kwargs, - ) = await interceptor.intercept( - method_name, - final_request_payload, - final_http_kwargs, - self.agent_card, - context, - ) - return final_request_payload, final_http_kwargs - - @staticmethod - async def get_client_from_agent_card_url( - httpx_client: httpx.AsyncClient, - base_url: str, - agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, - http_kwargs: dict[str, Any] | None = None, - ) -> 'A2AClient': - """Fetches the public AgentCard and initializes an A2A client. - - This method will always fetch the public agent card. If an authenticated - or extended agent card is required, the A2ACardResolver should be used - directly to fetch the specific card, and then the A2AClient should be - instantiated with it. + push_notification_configs: list[PushNotificationConfig] = dataclasses.field( + default_factory=list + ) + """Push notification callbacks to use for every request.""" - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - base_url: The base URL of the agent's host. - agent_card_path: The path to the agent card endpoint, relative to the base URL. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.get request when fetching the agent card. - - Returns: - An initialized `A2AClient` instance. - - Raises: - A2AClientHTTPError: If an HTTP error occurs fetching the agent card. - A2AClientJSONError: If the agent card response is invalid. - """ - agent_card: AgentCard = await A2ACardResolver( - httpx_client, base_url=base_url, agent_card_path=agent_card_path - ).get_agent_card( - http_kwargs=http_kwargs - ) # Fetches public card by default - return A2AClient(httpx_client=httpx_client, agent_card=agent_card) - async def send_message( - self, - request: SendMessageRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> SendMessageResponse: - """Sends a non-streaming message request to the agent. +UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None +# Alias for emitted events from client +ClientEvent = tuple[Task, UpdateEvent] +# Alias for an event consuming callback. It takes either a (task, update) pair +# or a message as well as the agent card for the agent this came from. +Consumer = Callable[ + [ClientEvent | Message, AgentCard], Coroutine[None, Any, Any] +] - Args: - request: The `SendMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - Returns: - A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. +class Client(ABC): + """Abstract base class defining the interface for an A2A client. - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/send', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SendMessageResponse.model_validate(response_data) - - async def send_message_streaming( - self, - request: SendStreamingMessageRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[SendStreamingMessageResponse]: - """Sends a streaming message request to the agent and yields responses as they arrive. + This class provides a standard set of methods for interacting with an A2A + agent, regardless of the underlying transport protocol (e.g., gRPC, JSON-RPC). + It supports sending messages, managing tasks, and handling event streams. + """ - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + def __init__( + self, + consumers: list[Consumer] | None = None, + middleware: list[ClientCallInterceptor] | None = None, + ): + """Initializes the client with consumers and middleware. Args: - request: The `SendStreamingMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - `SendStreamingMessageResponse` objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + consumers: A list of callables to process events from the agent. + middleware: A list of interceptors to process requests and responses. """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/stream', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'POST', - self.url, - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - yield SendStreamingMessageResponse.model_validate( - json.loads(sse.data) - ) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def _send_request( + if middleware is None: + middleware = [] + if consumers is None: + consumers = [] + self._consumers = consumers + self._middleware = middleware + + @abstractmethod + async def send_message( self, - rpc_request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Sends a non-streaming JSON-RPC request to the agent. - - Args: - rpc_request_payload: JSON RPC payload for sending the request. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - - Returns: - The JSON response payload as a dictionary. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON. + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent | Message]: + """Sends a message to the server. + + This will automatically use the streaming or non-streaming approach + as supported by the server and the client config. Client will + aggregate update events and return an iterator of (`Task`,`Update`) + pairs, or a `Message`. Client will also send these values to any + configured `Consumer`s in the client. """ - try: - response = await self.httpx_client.post( - self.url, json=rpc_request_payload, **(http_kwargs or {}) - ) - response.raise_for_status() - return response.json() - except httpx.ReadTimeout as e: - raise A2AClientTimeoutError('Client Request timed out') from e - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e + return + yield + @abstractmethod async def get_task( self, - request: GetTaskRequest, + request: TaskQueryParams, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> GetTaskResponse: - """Retrieves the current state and history of a specific task. - - Args: - request: The `GetTaskRequest` object specifying the task ID and history length. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `GetTaskResponse` object containing the Task or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskResponse.model_validate(response_data) + ) -> Task: + """Retrieves the current state and history of a specific task.""" + @abstractmethod async def cancel_task( self, - request: CancelTaskRequest, + request: TaskIdParams, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> CancelTaskResponse: - """Requests the agent to cancel a specific task. - - Args: - request: The `CancelTaskRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `CancelTaskResponse` object containing the updated Task with canceled status or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/cancel', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return CancelTaskResponse.model_validate(response_data) + ) -> Task: + """Requests the agent to cancel a specific task.""" + @abstractmethod async def set_task_callback( self, - request: SetTaskPushNotificationConfigRequest, + request: TaskPushNotificationConfig, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> SetTaskPushNotificationConfigResponse: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/set', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SetTaskPushNotificationConfigResponse.model_validate( - response_data - ) + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + @abstractmethod async def get_task_callback( self, - request: GetTaskPushNotificationConfigRequest, + request: GetTaskPushNotificationConfigParams, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> GetTaskPushNotificationConfigResponse: - """Retrieves the push notification configuration for a specific task. + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" - Args: - request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. + @abstractmethod + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent]: + """Resubscribes to a task's event stream.""" + return + yield + + @abstractmethod + async def get_card( + self, *, context: ClientCallContext | None = None + ) -> AgentCard: + """Retrieves the agent's card.""" - Returns: - A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. + async def add_event_consumer(self, consumer: Consumer) -> None: + """Attaches additional consumers to the `Client`.""" + self._consumers.append(consumer) - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskPushNotificationConfigResponse.model_validate( - response_data - ) + async def add_request_middleware( + self, middleware: ClientCallInterceptor + ) -> None: + """Attaches additional middleware to the `Client`.""" + self._middleware.append(middleware) + + async def consume( + self, + event: tuple[Task, UpdateEvent] | Message | None, + card: AgentCard, + ) -> None: + """Processes the event via all the registered `Consumer`s.""" + if not event: + return + for c in self._consumers: + await c(event, card) diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py new file mode 100644 index 00000000..c568331f --- /dev/null +++ b/src/a2a/client/client_factory.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import logging + +from collections.abc import Callable + +import httpx + +from a2a.client.base_client import BaseClient +from a2a.client.client import Client, ClientConfig, Consumer +from a2a.client.middleware import ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.client.transports.rest import RestTransport +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentInterface, + TransportProtocol, +) + + +try: + from a2a.client.transports.grpc import GrpcTransport +except ImportError: + GrpcTransport = None # type: ignore # pyright: ignore + + +logger = logging.getLogger(__name__) + + +TransportProducer = Callable[ + [AgentCard, str, ClientConfig, list[ClientCallInterceptor]], + ClientTransport, +] + + +class ClientFactory: + """ClientFactory is used to generate the appropriate client for the agent. + + The factory is configured with a `ClientConfig` and optionally a list of + `Consumer`s to use for all generated `Client`s. The expected use is: + + factory = ClientFactory(config, consumers) + # Optionally register custom client implementations + factory.register('my_customer_transport', NewCustomTransportClient) + # Then with an agent card make a client with additional consumers and + # interceptors + client = factory.create(card, additional_consumers, interceptors) + # Now the client can be used the same regardless of transport and + # aligns client config with server capabilities. + """ + + def __init__( + self, + config: ClientConfig, + consumers: list[Consumer] | None = None, + ): + if consumers is None: + consumers = [] + self._config = config + self._consumers = consumers + self._registry: dict[str, TransportProducer] = {} + self._register_defaults(config.supported_transports) + + def _register_defaults( + self, supported: list[str | TransportProtocol] + ) -> None: + # Empty support list implies JSON-RPC only. + if TransportProtocol.jsonrpc in supported or not supported: + self.register( + TransportProtocol.jsonrpc, + lambda card, url, config, interceptors: JsonRpcTransport( + config.httpx_client or httpx.AsyncClient(), + card, + url, + interceptors, + ), + ) + if TransportProtocol.http_json in supported: + self.register( + TransportProtocol.http_json, + lambda card, url, config, interceptors: RestTransport( + config.httpx_client or httpx.AsyncClient(), + card, + url, + interceptors, + ), + ) + if TransportProtocol.grpc in supported: + if GrpcTransport is None: + raise ImportError( + 'To use GrpcClient, its dependencies must be installed. ' + 'You can install them with \'pip install "a2a-sdk[grpc]"\'' + ) + self.register( + TransportProtocol.grpc, + GrpcTransport.create, + ) + + def register(self, label: str, generator: TransportProducer) -> None: + """Register a new transport producer for a given transport label.""" + self._registry[label] = generator + + def create( + self, + card: AgentCard, + consumers: list[Consumer] | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ) -> Client: + """Create a new `Client` for the provided `AgentCard`. + + Args: + card: An `AgentCard` defining the characteristics of the agent. + consumers: A list of `Consumer` methods to pass responses to. + interceptors: A list of interceptors to use for each request. These + are used for things like attaching credentials or http headers + to all outbound requests. + + Returns: + A `Client` object. + + Raises: + If there is no valid matching of the client configuration with the + server configuration, a `ValueError` is raised. + """ + server_preferred = card.preferred_transport or TransportProtocol.jsonrpc + server_set = {server_preferred: card.url} + if card.additional_interfaces: + server_set.update( + {x.transport: x.url for x in card.additional_interfaces} + ) + client_set = self._config.supported_transports or [ + TransportProtocol.jsonrpc + ] + transport_protocol = None + transport_url = None + if self._config.use_client_preference: + for x in client_set: + if x in server_set: + transport_protocol = x + transport_url = server_set[x] + break + else: + for x, url in server_set.items(): + if x in client_set: + transport_protocol = x + transport_url = url + break + if not transport_protocol or not transport_url: + raise ValueError('no compatible transports found.') + if transport_protocol not in self._registry: + raise ValueError(f'no client available for {transport_protocol}') + + all_consumers = self._consumers.copy() + if consumers: + all_consumers.extend(consumers) + + transport = self._registry[transport_protocol]( + card, transport_url, self._config, interceptors or [] + ) + + return BaseClient( + card, self._config, transport, all_consumers, interceptors or [] + ) + + +def minimal_agent_card( + url: str, transports: list[str] | None = None +) -> AgentCard: + """Generates a minimal card to simplify bootstrapping client creation. + + This minimal card is not viable itself to interact with the remote agent. + Instead this is a short hand way to take a known url and transport option + and interact with the get card endpoint of the agent server to get the + correct agent card. This pattern is necessary for gRPC based card access + as typically these servers won't expose a well known path card. + """ + if transports is None: + transports = [] + return AgentCard( + url=url, + preferred_transport=transports[0] if transports else None, + additional_interfaces=[ + AgentInterface(transport=t, url=url) for t in transports[1:] + ] + if len(transports) > 1 + else [], + supports_authenticated_extended_card=True, + capabilities=AgentCapabilities(), + default_input_modes=[], + default_output_modes=[], + description='', + skills=[], + version='', + name='', + ) diff --git a/src/a2a/client/client_task_manager.py b/src/a2a/client/client_task_manager.py new file mode 100644 index 00000000..060983e1 --- /dev/null +++ b/src/a2a/client/client_task_manager.py @@ -0,0 +1,192 @@ +import logging + +from a2a.client.errors import ( + A2AClientInvalidArgsError, + A2AClientInvalidStateError, +) +from a2a.server.events.event_queue import Event +from a2a.types import ( + Message, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.utils import append_artifact_to_task + + +logger = logging.getLogger(__name__) + + +class ClientTaskManager: + """Helps manage a task's lifecycle during execution of a request. + + Responsible for retrieving, saving, and updating the `Task` object based on + events received from the agent. + """ + + def __init__( + self, + ) -> None: + """Initializes the `ClientTaskManager`.""" + self._current_task: Task | None = None + self._task_id: str | None = None + self._context_id: str | None = None + + def get_task(self) -> Task | None: + """Retrieves the current task object, either from memory. + + If `task_id` is set, it returns `_current_task` otherwise None. + + Returns: + The `Task` object if found, otherwise `None`. + """ + if not self._task_id: + logger.debug('task_id is not set, cannot get task.') + return None + + return self._current_task + + def get_task_or_raise(self) -> Task: + """Retrieves the current task object. + + Returns: + The `Task` object. + + Raises: + A2AClientInvalidStateError: If there is no current known Task. + """ + if not (task := self.get_task()): + # Note: The source of this error is either from bad client usage + # or from the server sending invalid updates. It indicates that this + # task manager has not consumed any information about a task, yet + # the caller is attempting to retrieve the current state of the task + # it expects to be present. + raise A2AClientInvalidStateError('no current Task') + return task + + async def save_task_event( + self, event: Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ) -> Task | None: + """Processes a task-related event (Task, Status, Artifact) and saves the updated task state. + + Ensures task and context IDs match or are set from the event. + + Args: + event: The task-related event (`Task`, `TaskStatusUpdateEvent`, or `TaskArtifactUpdateEvent`). + + Returns: + The updated `Task` object after processing the event. + + Raises: + ClientError: If the task ID in the event conflicts with the TaskManager's ID + when the TaskManager's ID is already set. + """ + if isinstance(event, Task): + if self._current_task: + raise A2AClientInvalidArgsError( + 'Task is already set, create new manager for new tasks.' + ) + await self._save_task(event) + return event + task_id_from_event = ( + event.id if isinstance(event, Task) else event.task_id + ) + if not self._task_id: + self._task_id = task_id_from_event + if not self._context_id: + self._context_id = event.context_id + + logger.debug( + 'Processing save of task event of type %s for task_id: %s', + type(event).__name__, + task_id_from_event, + ) + + task = self._current_task + if not task: + task = Task( + status=TaskStatus(state=TaskState.unknown), + id=task_id_from_event, + context_id=self._context_id if self._context_id else '', + ) + if isinstance(event, TaskStatusUpdateEvent): + logger.debug( + 'Updating task %s status to: %s', + event.task_id, + event.status.state, + ) + if event.status.message: + if not task.history: + task.history = [event.status.message] + else: + task.history.append(event.status.message) + if event.metadata: + if not task.metadata: + task.metadata = {} + task.metadata.update(event.metadata) + task.status = event.status + else: + logger.debug('Appending artifact to task %s', task.id) + append_artifact_to_task(task, event) + self._current_task = task + return task + + async def process(self, event: Event) -> Event: + """Processes an event, updates the task state if applicable, stores it, and returns the event. + + If the event is task-related (`Task`, `TaskStatusUpdateEvent`, `TaskArtifactUpdateEvent`), + the internal task state is updated and persisted. + + Args: + event: The event object received from the agent. + + Returns: + The same event object that was processed. + """ + if isinstance( + event, Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ): + await self.save_task_event(event) + + return event + + async def _save_task(self, task: Task) -> None: + """Saves the given task to the `_current_task` and updated `_task_id` and `_context_id`. + + Args: + task: The `Task` object to save. + """ + logger.debug('Saving task with id: %s', task.id) + self._current_task = task + if not self._task_id: + logger.info('New task created with id: %s', task.id) + self._task_id = task.id + self._context_id = task.context_id + + def update_with_message(self, message: Message, task: Task) -> Task: + """Updates a task object adding a new message to its history. + + If the task has a message in its current status, that message is moved + to the history first. + + Args: + message: The new `Message` to add to the history. + task: The `Task` object to update. + + Returns: + The updated `Task` object (updated in-place). + """ + if task.status.message: + if task.history: + task.history.append(task.status.message) + else: + task.history = [task.status.message] + task.status.message = None + if task.history: + task.history.append(message) + else: + task.history = [message] + self._current_task = task + return task diff --git a/src/a2a/client/errors.py b/src/a2a/client/errors.py index 5fe5512a..890c3726 100644 --- a/src/a2a/client/errors.py +++ b/src/a2a/client/errors.py @@ -1,5 +1,7 @@ """Custom exceptions for the A2A client.""" +from a2a.types import JSONRPCErrorResponse + class A2AClientError(Exception): """Base exception for A2A Client errors.""" @@ -44,3 +46,42 @@ def __init__(self, message: str): """ self.message = message super().__init__(f'Timeout Error: {message}') + + +class A2AClientInvalidArgsError(A2AClientError): + """Client exception for invalid arguments passed to a method.""" + + def __init__(self, message: str): + """Initializes the A2AClientInvalidArgsError. + + Args: + message: A descriptive error message. + """ + self.message = message + super().__init__(f'Invalid arguments error: {message}') + + +class A2AClientInvalidStateError(A2AClientError): + """Client exception for an invalid client state.""" + + def __init__(self, message: str): + """Initializes the A2AClientInvalidStateError. + + Args: + message: A descriptive error message. + """ + self.message = message + super().__init__(f'Invalid state error: {message}') + + +class A2AClientJSONRPCError(A2AClientError): + """Client exception for JSON-RPC errors returned by the server.""" + + def __init__(self, error: JSONRPCErrorResponse): + """Initializes the A2AClientJsonRPCError. + + Args: + error: The JSON-RPC error object. + """ + self.error = error.error + super().__init__(f'JSON-RPC Error {error.error}') diff --git a/src/a2a/client/legacy.py b/src/a2a/client/legacy.py new file mode 100644 index 00000000..4318543d --- /dev/null +++ b/src/a2a/client/legacy.py @@ -0,0 +1,344 @@ +"""Backwards compatibility layer for legacy A2A clients.""" + +import warnings + +from collections.abc import AsyncGenerator +from typing import Any + +import httpx + +from a2a.client.errors import A2AClientJSONRPCError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.types import ( + AgentCard, + CancelTaskRequest, + CancelTaskResponse, + CancelTaskSuccessResponse, + GetTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigSuccessResponse, + GetTaskRequest, + GetTaskResponse, + GetTaskSuccessResponse, + JSONRPCErrorResponse, + SendMessageRequest, + SendMessageResponse, + SendMessageSuccessResponse, + SendStreamingMessageRequest, + SendStreamingMessageResponse, + SendStreamingMessageSuccessResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + SetTaskPushNotificationConfigSuccessResponse, + TaskIdParams, + TaskResubscriptionRequest, +) + + +class A2AClient: + """[DEPRECATED] Backwards compatibility wrapper for the JSON-RPC client.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + warnings.warn( + 'A2AClient is deprecated and will be removed in a future version. ' + 'Use ClientFactory to create a client with a JSON-RPC transport.', + DeprecationWarning, + stacklevel=2, + ) + self._transport = JsonRpcTransport( + httpx_client, agent_card, url, interceptors + ) + + async def send_message( + self, + request: SendMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SendMessageResponse: + """Sends a non-streaming message request to the agent. + + Args: + request: The `SendMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + + try: + result = await self._transport.send_message( + request.params, context=context + ) + return SendMessageResponse( + root=SendMessageSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return SendMessageResponse(JSONRPCErrorResponse(error=e.error)) + + async def send_message_streaming( + self, + request: SendStreamingMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse, None]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `SendStreamingMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + + async for result in self._transport.send_message_streaming( + request.params, context=context + ): + yield SendStreamingMessageResponse( + root=SendStreamingMessageSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + + async def get_task( + self, + request: GetTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskResponse: + """Retrieves the current state and history of a specific task. + + Args: + request: The `GetTaskRequest` object specifying the task ID and history length. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskResponse` object containing the Task or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.get_task( + request.params, context=context + ) + return GetTaskResponse( + root=GetTaskSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return GetTaskResponse(root=JSONRPCErrorResponse(error=e.error)) + + async def cancel_task( + self, + request: CancelTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> CancelTaskResponse: + """Requests the agent to cancel a specific task. + + Args: + request: The `CancelTaskRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `CancelTaskResponse` object containing the updated Task with canceled status or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.cancel_task( + request.params, context=context + ) + return CancelTaskResponse( + root=CancelTaskSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return CancelTaskResponse(JSONRPCErrorResponse(error=e.error)) + + async def set_task_callback( + self, + request: SetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SetTaskPushNotificationConfigResponse: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.set_task_callback( + request.params, context=context + ) + return SetTaskPushNotificationConfigResponse( + root=SetTaskPushNotificationConfigSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return SetTaskPushNotificationConfigResponse( + JSONRPCErrorResponse(error=e.error) + ) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskPushNotificationConfigResponse: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + params = request.params + if isinstance(params, TaskIdParams): + params = GetTaskPushNotificationConfigParams(id=request.params.id) + try: + result = await self._transport.get_task_callback( + params, context=context + ) + return GetTaskPushNotificationConfigResponse( + root=GetTaskPushNotificationConfigSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return GetTaskPushNotificationConfigResponse( + JSONRPCErrorResponse(error=e.error) + ) + + async def resubscribe( + self, + request: TaskResubscriptionRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse, None]: + """Reconnects to get task updates. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `TaskResubscriptionRequest` object containing the task information to reconnect to. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + + async for result in self._transport.resubscribe( + request.params, context=context + ): + yield SendStreamingMessageResponse( + root=SendStreamingMessageSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + + async def get_card( + self, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the authenticated card (if necessary) or the public one. + + Args: + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `AgentCard` object containing the card or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + return await self._transport.get_card(context=context) diff --git a/src/a2a/client/legacy_grpc.py b/src/a2a/client/legacy_grpc.py new file mode 100644 index 00000000..0b62b009 --- /dev/null +++ b/src/a2a/client/legacy_grpc.py @@ -0,0 +1,44 @@ +"""Backwards compatibility layer for the legacy A2A gRPC client.""" + +import warnings + +from typing import TYPE_CHECKING + +from a2a.client.transports.grpc import GrpcTransport +from a2a.types import AgentCard + + +if TYPE_CHECKING: + from a2a.grpc.a2a_pb2_grpc import A2AServiceStub + + +class A2AGrpcClient(GrpcTransport): + """[DEPRECATED] Backwards compatibility wrapper for the gRPC client.""" + + def __init__( # pylint: disable=super-init-not-called + self, + grpc_stub: 'A2AServiceStub', + agent_card: AgentCard, + ): + warnings.warn( + 'A2AGrpcClient is deprecated and will be removed in a future version. ' + 'Use ClientFactory to create a client with a gRPC transport.', + DeprecationWarning, + stacklevel=2, + ) + # The old gRPC client accepted a stub directly. The new one accepts a + # channel and builds the stub itself. We just have a stub here, so we + # need to handle initialization ourselves. + self.stub = grpc_stub + self.agent_card = agent_card + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card + if agent_card + else True + ) + + class _NopChannel: + async def close(self) -> None: + pass + + self.channel = _NopChannel() diff --git a/src/a2a/client/optionals.py b/src/a2a/client/optionals.py new file mode 100644 index 00000000..f55f0186 --- /dev/null +++ b/src/a2a/client/optionals.py @@ -0,0 +1,16 @@ +from typing import TYPE_CHECKING + + +# Attempt to import the optional module +try: + from grpc.aio import Channel # pyright: ignore[reportAssignmentType] +except ImportError: + # If grpc.aio is not available, define a dummy type for type checking. + # This dummy type will only be used by type checkers. + if TYPE_CHECKING: + + class Channel: # type: ignore[no-redef] + """Dummy class for type hinting when grpc.aio is not available.""" + + else: + Channel = None # At runtime, pd will be None if the import failed. diff --git a/src/a2a/client/transports/__init__.py b/src/a2a/client/transports/__init__.py new file mode 100644 index 00000000..af7c60f6 --- /dev/null +++ b/src/a2a/client/transports/__init__.py @@ -0,0 +1,19 @@ +"""A2A Client Transports.""" + +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.client.transports.rest import RestTransport + + +try: + from a2a.client.transports.grpc import GrpcTransport +except ImportError: + GrpcTransport = None # type: ignore + + +__all__ = [ + 'ClientTransport', + 'GrpcTransport', + 'JsonRpcTransport', + 'RestTransport', +] diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py new file mode 100644 index 00000000..3573cb7c --- /dev/null +++ b/src/a2a/client/transports/base.py @@ -0,0 +1,103 @@ +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator + +from a2a.client.middleware import ClientCallContext +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) + + +class ClientTransport(ABC): + """Abstract base class for a client transport.""" + + @abstractmethod + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + + @abstractmethod + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + return + yield + + @abstractmethod + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + + @abstractmethod + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + + @abstractmethod + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + + @abstractmethod + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + + @abstractmethod + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Reconnects to get task updates.""" + return + yield + + @abstractmethod + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the AgentCard.""" + + @abstractmethod + async def close(self) -> None: + """Closes the transport.""" diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/transports/grpc.py similarity index 50% rename from src/a2a/client/grpc_client.py rename to src/a2a/client/transports/grpc.py index d224b201..e64c1534 100644 --- a/src/a2a/client/grpc_client.py +++ b/src/a2a/client/transports/grpc.py @@ -12,10 +12,14 @@ "'pip install a2a-sdk[grpc]'" ) from e - +from a2a.client.client import ClientConfig +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.optionals import Channel +from a2a.client.transports.base import ClientTransport from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, + GetTaskPushNotificationConfigParams, Message, MessageSendParams, Task, @@ -33,37 +37,47 @@ @trace_class(kind=SpanKind.CLIENT) -class A2AGrpcClient: - """A2A Client for interacting with an A2A agent via gRPC.""" +class GrpcTransport(ClientTransport): + """A gRPC transport for the A2A client.""" def __init__( self, - grpc_stub: a2a_pb2_grpc.A2AServiceStub, - agent_card: AgentCard, + channel: Channel, + agent_card: AgentCard | None, ): - """Initializes the A2AGrpcClient. - - Requires an `AgentCard` - - Args: - grpc_stub: A grpc client stub. - agent_card: The agent card object. - """ + """Initializes the GrpcTransport.""" self.agent_card = agent_card - self.stub = grpc_stub + self.channel = channel + self.stub = a2a_pb2_grpc.A2AServiceStub(channel) + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card + if agent_card + else True + ) + + @classmethod + def create( + cls, + card: AgentCard, + url: str, + config: ClientConfig, + interceptors: list[ClientCallInterceptor], + ) -> 'GrpcTransport': + """Creates a gRPC transport for the A2A client.""" + if config.grpc_channel_factory is None: + raise ValueError('grpc_channel_factory is required when using gRPC') + return cls( + config.grpc_channel_factory(url), + card, + ) async def send_message( self, request: MessageSendParams, + *, + context: ClientCallContext | None = None, ) -> Task | Message: - """Sends a non-streaming message request to the agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - - Returns: - A `Task` or `Message` object containing the agent's response. - """ + """Sends a non-streaming message request to the agent.""" response = await self.stub.SendMessage( a2a_pb2.SendMessageRequest( request=proto_utils.ToProto.message(request.message), @@ -80,22 +94,12 @@ async def send_message( async def send_message_streaming( self, request: MessageSendParams, + *, + context: ClientCallContext | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: - """Sends a streaming message request to the agent and yields responses as they arrive. - - This method uses gRPC streams to receive a stream of updates from the - agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - - Yields: - `Message` or `Task` or `TaskStatusUpdateEvent` or - `TaskArtifactUpdateEvent` objects as they are received in the - stream. - """ + """Sends a streaming message request to the agent and yields responses as they arrive.""" stream = self.stub.SendStreamingMessage( a2a_pb2.SendMessageRequest( request=proto_utils.ToProto.message(request.message), @@ -107,33 +111,32 @@ async def send_message_streaming( ) while True: response = await stream.read() - if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue] + if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue] break - if response.HasField('msg'): - yield proto_utils.FromProto.message(response.msg) - elif response.HasField('task'): - yield proto_utils.FromProto.task(response.task) - elif response.HasField('status_update'): - yield proto_utils.FromProto.task_status_update_event( - response.status_update - ) - elif response.HasField('artifact_update'): - yield proto_utils.FromProto.task_artifact_update_event( - response.artifact_update - ) + yield proto_utils.FromProto.stream_response(response) + + async def resubscribe( + self, request: TaskIdParams, *, context: ClientCallContext | None = None + ) -> AsyncGenerator[ + Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Reconnects to get task updates.""" + stream = self.stub.TaskSubscription( + a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}') + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue] + break + yield proto_utils.FromProto.stream_response(response) async def get_task( self, request: TaskQueryParams, + *, + context: ClientCallContext | None = None, ) -> Task: - """Retrieves the current state and history of a specific task. - - Args: - request: The `TaskQueryParams` object specifying the task ID - - Returns: - A `Task` object containing the Task or None. - """ + """Retrieves the current state and history of a specific task.""" task = await self.stub.GetTask( a2a_pb2.GetTaskRequest(name=f'tasks/{request.id}') ) @@ -142,15 +145,10 @@ async def get_task( async def cancel_task( self, request: TaskIdParams, + *, + context: ClientCallContext | None = None, ) -> Task: - """Requests the agent to cancel a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - - Returns: - A `Task` object containing the updated Task - """ + """Requests the agent to cancel a specific task.""" task = await self.stub.CancelTask( a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') ) @@ -159,19 +157,14 @@ async def cancel_task( async def set_task_callback( self, request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, ) -> TaskPushNotificationConfig: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. - - Returns: - A `TaskPushNotificationConfig` object containing the config. - """ + """Sets or updates the push notification configuration for a specific task.""" config = await self.stub.CreateTaskPushNotificationConfig( a2a_pb2.CreateTaskPushNotificationConfigRequest( - parent='', - config_id='', + parent=f'tasks/{request.task_id}', + config_id=request.push_notification_config.id, config=proto_utils.ToProto.task_push_notification_config( request ), @@ -181,19 +174,38 @@ async def set_task_callback( async def get_task_callback( self, - request: TaskIdParams, # TODO: Update to a push id params + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, ) -> TaskPushNotificationConfig: - """Retrieves the push notification configuration for a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - - Returns: - A `TaskPushNotificationConfig` object containing the configuration. - """ + """Retrieves the push notification configuration for a specific task.""" config = await self.stub.GetTaskPushNotificationConfig( a2a_pb2.GetTaskPushNotificationConfigRequest( - name=f'tasks/{request.id}/pushNotification/undefined', + name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', ) ) return proto_utils.FromProto.task_push_notification_config(config) + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + card = self.agent_card + if card and not self._needs_extended_card: + return card + if card is None and not self._needs_extended_card: + raise ValueError('Agent card is not available.') + + card_pb = await self.stub.GetAgentCard( + a2a_pb2.GetAgentCardRequest(), + ) + card = proto_utils.FromProto.agent_card(card_pb) + self.agent_card = card + self._needs_extended_card = False + return card + + async def close(self) -> None: + """Closes the gRPC channel.""" + await self.channel.close() diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py new file mode 100644 index 00000000..868b3a01 --- /dev/null +++ b/src/a2a/client/transports/jsonrpc.py @@ -0,0 +1,376 @@ +import json +import logging + +from collections.abc import AsyncGenerator +from typing import Any +from uuid import uuid4 + +import httpx + +from httpx_sse import SSEError, aconnect_sse + +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, + A2AClientJSONRPCError, + A2AClientTimeoutError, +) +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.types import ( + AgentCard, + CancelTaskRequest, + CancelTaskResponse, + GetAuthenticatedExtendedCardRequest, + GetAuthenticatedExtendedCardResponse, + GetTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskRequest, + GetTaskResponse, + JSONRPCErrorResponse, + Message, + MessageSendParams, + SendMessageRequest, + SendMessageResponse, + SendStreamingMessageRequest, + SendStreamingMessageResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskResubscriptionRequest, + TaskStatusUpdateEvent, +) +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class JsonRpcTransport(ClientTransport): + """A JSON-RPC transport for the A2A client.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the JsonRpcTransport.""" + if url: + self.url = url + elif agent_card: + self.url = agent_card.url + else: + raise ValueError('Must provide either agent_card or url') + + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card + if agent_card + else True + ) + + async def _apply_interceptors( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + + for interceptor in self.interceptors: + ( + final_request_payload, + final_http_kwargs, + ) = await interceptor.intercept( + method_name, + final_request_payload, + final_http_kwargs, + self.agent_card, + context, + ) + return final_request_payload, final_http_kwargs + + def _get_http_args( + self, context: ClientCallContext | None + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs') if context else None + + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + rpc_request = SendMessageRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + 'message/send', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = SendMessageResponse.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + rpc_request = SendStreamingMessageRequest( + params=request, id=str(uuid4()) + ) + payload, modified_kwargs = await self._apply_interceptors( + 'message/stream', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + response = SendStreamingMessageResponse.model_validate( + json.loads(sse.data) + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + yield response.root.result + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_request( + self, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + try: + response = await self.httpx_client.post( + self.url, json=rpc_request_payload, **(http_kwargs or {}) + ) + response.raise_for_status() + return response.json() + except httpx.ReadTimeout as e: + raise A2AClientTimeoutError('Client Request timed out') from e + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + rpc_request = GetTaskRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/get', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = GetTaskResponse.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + rpc_request = CancelTaskRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/cancel', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = CancelTaskResponse.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + rpc_request = SetTaskPushNotificationConfigRequest( + params=request, id=str(uuid4()) + ) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/set', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = SetTaskPushNotificationConfigResponse.model_validate( + response_data + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + rpc_request = GetTaskPushNotificationConfigRequest( + params=request, id=str(uuid4()) + ) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/get', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = GetTaskPushNotificationConfigResponse.model_validate( + response_data + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Reconnects to get task updates.""" + rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/resubscribe', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + response = SendStreamingMessageResponse.model_validate_json( + sse.data + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + yield response.root.result + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + card = self.agent_card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = await resolver.get_agent_card( + http_kwargs=self._get_http_args(context) + ) + self._needs_extended_card = ( + card.supports_authenticated_extended_card + ) + self.agent_card = card + + if not self._needs_extended_card: + return card + + request = GetAuthenticatedExtendedCardRequest(id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + request.method, + request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + + response_data = await self._send_request( + payload, + modified_kwargs, + ) + response = GetAuthenticatedExtendedCardResponse.model_validate( + response_data + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + self.agent_card = response.root.result + self._needs_extended_card = False + return card + + async def close(self) -> None: + """Closes the httpx client.""" + await self.httpx_client.aclose() diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py new file mode 100644 index 00000000..3a72a5b1 --- /dev/null +++ b/src/a2a/client/transports/rest.py @@ -0,0 +1,365 @@ +import json +import logging + +from collections.abc import AsyncGenerator +from typing import Any + +import httpx + +from google.protobuf.json_format import MessageToDict, Parse, ParseDict +from httpx_sse import SSEError, aconnect_sse + +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.grpc import a2a_pb2 +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) +from a2a.utils import proto_utils +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class RestTransport(ClientTransport): + """A REST transport for the A2A client.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the RestTransport.""" + if url: + self.url = url + elif agent_card: + self.url = agent_card.url + else: + raise ValueError('Must provide either agent_card or url') + if self.url.endswith('/'): + self.url = self.url[:-1] + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card + if agent_card + else True + ) + + async def _apply_interceptors( + self, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + # TODO: Implement interceptors for other transports + return final_request_payload, final_http_kwargs + + def _get_http_args( + self, context: ClientCallContext | None + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs') if context else None + + async def _prepare_send_message( + self, request: MessageSendParams, context: ClientCallContext | None + ) -> tuple[dict[str, Any], dict[str, Any]]: + pb = a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=( + proto_utils.ToProto.metadata(request.metadata) + if request.metadata + else None + ), + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + return payload, modified_kwargs + + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + payload, modified_kwargs = await self._prepare_send_message( + request, context + ) + response_data = await self._send_post_request( + '/v1/message:send', payload, modified_kwargs + ) + response_pb = a2a_pb2.SendMessageResponse() + ParseDict(response_data, response_pb) + return proto_utils.FromProto.task_or_message(response_pb) + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + payload, modified_kwargs = await self._prepare_send_message( + request, context + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + f'{self.url}/v1/message:stream', + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_request(self, request: httpx.Request) -> dict[str, Any]: + try: + response = await self.httpx_client.send(request) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_post_request( + self, + target: str, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + return await self._send_request( + self.httpx_client.build_request( + 'POST', + f'{self.url}{target}', + json=rpc_request_payload, + **(http_kwargs or {}), + ) + ) + + async def _send_get_request( + self, + target: str, + query_params: dict[str, str], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + return await self._send_request( + self.httpx_client.build_request( + 'GET', + f'{self.url}{target}', + params=query_params, + **(http_kwargs or {}), + ) + ) + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + payload, modified_kwargs = await self._apply_interceptors( + request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.id}', + {'historyLength': str(request.history_length)} + if request.history_length + else {}, + modified_kwargs, + ) + task = a2a_pb2.Task() + ParseDict(response_data, task) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs + ) + task = a2a_pb2.Task() + ParseDict(response_data, task) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( + parent=f'tasks/{request.task_id}', + config_id=request.push_notification_config.id, + config=proto_utils.ToProto.task_push_notification_config(request), + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, self._get_http_args(context), context + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.task_id}/pushNotificationConfigs', + payload, + modified_kwargs, + ) + config = a2a_pb2.TaskPushNotificationConfig() + ParseDict(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + pb = a2a_pb2.GetTaskPushNotificationConfigRequest( + name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + {}, + modified_kwargs, + ) + config = a2a_pb2.TaskPushNotificationConfig() + ParseDict(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: + """Reconnects to get task updates.""" + http_kwargs = self._get_http_args(context) or {} + http_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'GET', + f'{self.url}/v1/tasks/{request.id}:subscribe', + **http_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + card = self.agent_card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = await resolver.get_agent_card( + http_kwargs=self._get_http_args(context) + ) + self._needs_extended_card = ( + card.supports_authenticated_extended_card + ) + self.agent_card = card + + if not self._needs_extended_card: + return card + + _, modified_kwargs = await self._apply_interceptors( + {}, + self._get_http_args(context), + context, + ) + response_data = await self._send_get_request( + '/v1/card', {}, modified_kwargs + ) + card = AgentCard.model_validate(response_data) + self.agent_card = card + self._needs_extended_card = False + return card + + async def close(self) -> None: + """Closes the httpx client.""" + await self.httpx_client.aclose() diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py index a73e05c8..579deaa5 100644 --- a/src/a2a/server/apps/__init__.py +++ b/src/a2a/server/apps/__init__.py @@ -6,10 +6,12 @@ CallContextBuilder, JSONRPCApplication, ) +from a2a.server.apps.rest import A2ARESTFastAPIApplication __all__ = [ 'A2AFastAPIApplication', + 'A2ARESTFastAPIApplication', 'A2AStarletteApplication', 'CallContextBuilder', 'JSONRPCApplication', diff --git a/src/a2a/server/apps/jsonrpc/__init__.py b/src/a2a/server/apps/jsonrpc/__init__.py index ab803d4e..1121fdbc 100644 --- a/src/a2a/server/apps/jsonrpc/__init__.py +++ b/src/a2a/server/apps/jsonrpc/__init__.py @@ -3,7 +3,9 @@ from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication from a2a.server.apps.jsonrpc.jsonrpc_app import ( CallContextBuilder, + DefaultCallContextBuilder, JSONRPCApplication, + StarletteUserProxy, ) from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication @@ -12,5 +14,7 @@ 'A2AFastAPIApplication', 'A2AStarletteApplication', 'CallContextBuilder', + 'DefaultCallContextBuilder', 'JSONRPCApplication', + 'StarletteUserProxy', ] diff --git a/src/a2a/server/apps/rest/__init__.py b/src/a2a/server/apps/rest/__init__.py new file mode 100644 index 00000000..bafe4cb6 --- /dev/null +++ b/src/a2a/server/apps/rest/__init__.py @@ -0,0 +1,8 @@ +"""A2A REST Applications.""" + +from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication + + +__all__ = [ + 'A2ARESTFastAPIApplication', +] diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py new file mode 100644 index 00000000..e4092b12 --- /dev/null +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -0,0 +1,78 @@ +import logging + +from typing import Any + +from fastapi import APIRouter, FastAPI, Request, Response + +from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder +from a2a.server.apps.rest.rest_adapter import RESTAdapter +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import AgentCard +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH + + +logger = logging.getLogger(__name__) + + +class A2ARESTFastAPIApplication: + """A FastAPI application implementing the A2A protocol server REST endpoints. + + Handles incoming REST requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + def __init__( + self, + agent_card: AgentCard, + http_handler: RequestHandler, + context_builder: CallContextBuilder | None = None, + ): + """Initializes the A2ARESTFastAPIApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + """ + self._adapter = RESTAdapter( + agent_card=agent_card, + http_handler=http_handler, + context_builder=context_builder, + ) + + def build( + self, + agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, + rpc_url: str = '', + **kwargs: Any, + ) -> FastAPI: + """Builds and returns the FastAPI application instance. + + Args: + agent_card_url: The URL for the agent card endpoint. + rpc_url: The URL for the A2A JSON-RPC endpoint. + extended_agent_card_url: The URL for the authenticated extended agent card endpoint. + **kwargs: Additional keyword arguments to pass to the FastAPI constructor. + + Returns: + A configured FastAPI application instance. + """ + app = FastAPI(**kwargs) + router = APIRouter() + for route, callback in self._adapter.routes().items(): + router.add_api_route( + f'{rpc_url}{route[0]}', callback, methods=[route[1]] + ) + + @router.get(f'{rpc_url}{agent_card_url}') + async def get_agent_card(request: Request) -> Response: + return await self._adapter.handle_get_agent_card(request) + + app.include_router(router) + return app diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py new file mode 100644 index 00000000..102349b9 --- /dev/null +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -0,0 +1,181 @@ +import functools +import logging + +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from typing import Any + +from sse_starlette.sse import EventSourceResponse +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from a2a.server.apps.jsonrpc import ( + CallContextBuilder, + DefaultCallContextBuilder, +) +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.rest_handler import RESTHandler +from a2a.types import AgentCard, AuthenticatedExtendedCardNotConfiguredError +from a2a.utils.error_handlers import ( + rest_error_handler, + rest_stream_error_handler, +) +from a2a.utils.errors import ServerError + + +logger = logging.getLogger(__name__) + + +class RESTAdapter: + """Adapter to make RequestHandler work with RESTful API. + + Defines REST requests processors and the routes to attach them too, as well as + manages response generation including Server-Sent Events (SSE). + """ + + def __init__( + self, + agent_card: AgentCard, + http_handler: RequestHandler, + context_builder: CallContextBuilder | None = None, + ): + """Initializes the RESTApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + """ + self.agent_card = agent_card + self.handler = RESTHandler( + agent_card=agent_card, request_handler=http_handler + ) + self._context_builder = context_builder or DefaultCallContextBuilder() + + @rest_error_handler + async def _handle_request( + self, + method: Callable[[Request, ServerCallContext], Awaitable[Any]], + request: Request, + ) -> Response: + call_context = self._context_builder.build(request) + response = await method(request, call_context) + return JSONResponse(content=response) + + @rest_stream_error_handler + async def _handle_streaming_request( + self, + method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], + request: Request, + ) -> EventSourceResponse: + call_context = self._context_builder.build(request) + + async def event_generator( + stream: AsyncIterable[Any], + ) -> AsyncIterator[dict[str, dict[str, Any]]]: + async for item in stream: + yield {'data': item} + + return EventSourceResponse( + event_generator(method(request, call_context)) + ) + + @rest_error_handler + async def handle_get_agent_card(self, request: Request) -> JSONResponse: + """Handles GET requests for the agent card endpoint. + + Args: + request: The incoming Starlette Request object. + + Returns: + A JSONResponse containing the agent card data. + """ + # The public agent card is a direct serialization of the agent_card + # provided at initialization. + return JSONResponse( + self.agent_card.model_dump(mode='json', exclude_none=True) + ) + + @rest_error_handler + async def handle_authenticated_agent_card( + self, request: Request + ) -> JSONResponse: + """Hook for per credential agent card response. + + If a dynamic card is needed based on the credentials provided in the request + override this method and return the customized content. + + Args: + request: The incoming Starlette Request object. + + Returns: + A JSONResponse containing the authenticated card. + """ + if not self.agent_card.supports_authenticated_extended_card: + raise ServerError( + error=AuthenticatedExtendedCardNotConfiguredError( + message='Authenticated card not supported' + ) + ) + return JSONResponse( + self.agent_card.model_dump(mode='json', exclude_none=True) + ) + + def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: + """Constructs a dictionary of API routes and their corresponding handlers. + + This method maps URL paths and HTTP methods to the appropriate handler + functions from the RESTHandler. It can be used by a web framework + (like Starlette or FastAPI) to set up the application's endpoints. + + Returns: + A dictionary where each key is a tuple of (path, http_method) and + the value is the callable handler for that route. + """ + routes: dict[tuple[str, str], Callable[[Request], Any]] = { + ('/v1/message:send', 'POST'): functools.partial( + self._handle_request, self.handler.on_message_send + ), + ('/v1/message:stream', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_message_send_stream, + ), + ('/v1/tasks/{id}:cancel', 'POST'): functools.partial( + self._handle_request, self.handler.on_cancel_task + ), + ('/v1/tasks/{id}:subscribe', 'GET'): functools.partial( + self._handle_streaming_request, + self.handler.on_resubscribe_to_task, + ), + ('/v1/tasks/{id}', 'GET'): functools.partial( + self._handle_request, self.handler.on_get_task + ), + ( + '/v1/tasks/{id}/pushNotificationConfigs/{push_id}', + 'GET', + ): functools.partial( + self._handle_request, self.handler.get_push_notification + ), + ( + '/v1/tasks/{id}/pushNotificationConfigs', + 'POST', + ): functools.partial( + self._handle_request, self.handler.set_push_notification + ), + ( + '/v1/tasks/{id}/pushNotificationConfigs', + 'GET', + ): functools.partial( + self._handle_request, self.handler.list_push_notifications + ), + ('/v1/tasks', 'GET'): functools.partial( + self._handle_request, self.handler.list_tasks + ), + } + if self.agent_card.supports_authenticated_extended_card: + routes[('/v1/card', 'GET')] = self.handle_authenticated_agent_card + + return routes diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index 38b0f700..a094005d 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -182,7 +182,7 @@ def create_task_model( TaskModel = create_task_model('tasks', MyBase) """ - class TaskModel(TaskMixin, base): + class TaskModel(TaskMixin, base): # type: ignore __tablename__ = table_name @override @@ -235,7 +235,7 @@ def create_push_notification_config_model( ) -> type: """Create a PushNotificationConfigModel class with a configurable table name.""" - class PushNotificationConfigModel(PushNotificationConfigMixin, base): + class PushNotificationConfigModel(PushNotificationConfigMixin, base): # type: ignore __tablename__ = table_name @override diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 9882dc2a..43ebc8e2 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -11,6 +11,7 @@ build_error_response, prepare_response_object, ) +from a2a.server.request_handlers.rest_handler import RESTHandler logger = logging.getLogger(__name__) @@ -40,6 +41,7 @@ def __init__(self, *args, **kwargs): 'DefaultRequestHandler', 'GrpcHandler', 'JSONRPCHandler', + 'RESTHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 2761ed33..5fc15cf9 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -286,7 +286,7 @@ async def CreateTaskPushNotificationConfig( server_context = self.context_builder.build(context) config = ( await self.request_handler.on_set_task_push_notification_config( - proto_utils.FromProto.task_push_notification_config( + proto_utils.FromProto.task_push_notification_config_request( request, ), server_context, diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py new file mode 100644 index 00000000..48f57b75 --- /dev/null +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -0,0 +1,298 @@ +import logging + +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any + +from google.protobuf.json_format import MessageToDict, MessageToJson, Parse +from starlette.requests import Request + +from a2a.grpc import a2a_pb2 +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + TaskIdParams, + TaskNotFoundError, + TaskQueryParams, +) +from a2a.utils import proto_utils +from a2a.utils.errors import ServerError +from a2a.utils.helpers import validate +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.SERVER) +class RESTHandler: + """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses. + + This uses the protobuf definitions of the gRPC service as the source of truth. By + doing this, it ensures that this implementation and the gRPC transcoding + (via Envoy) are equivalent. This handler should be used if using the gRPC handler + with Envoy is not feasible for a given deployment solution. Use this handler + and a related application if you desire to ONLY server the RESTful API. + """ + + def __init__( + self, + agent_card: AgentCard, + request_handler: RequestHandler, + ): + """Initializes the RESTHandler. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The underlying `RequestHandler` instance to delegate requests to. + """ + self.agent_card = agent_card + self.request_handler = request_handler + + async def on_message_send( + self, + request: Request, + context: ServerCallContext, + ) -> dict[str, Any]: + """Handles the 'message/send' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `dict` containing the result (Task or Message) + """ + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + params, + ) + task_or_message = await self.request_handler.on_message_send( + a2a_request, context + ) + return MessageToDict( + proto_utils.ToProto.task_or_message(task_or_message) + ) + + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_message_send_stream( + self, + request: Request, + context: ServerCallContext, + ) -> AsyncIterator[str]: + """Handles the 'message/stream' REST method. + + Yields response objects as they are produced by the underlying handler's stream. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Yields: + JSON serialized objects containing streaming events + (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON + """ + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + params, + ) + async for event in self.request_handler.on_message_send_stream( + a2a_request, context + ): + response = proto_utils.ToProto.stream_response(event) + yield MessageToJson(response) + + async def on_cancel_task( + self, + request: Request, + context: ServerCallContext, + ) -> dict[str, Any]: + """Handles the 'tasks/cancel' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `dict` containing the updated Task + """ + task_id = request.path_params['id'] + task = await self.request_handler.on_cancel_task( + TaskIdParams(id=task_id), context + ) + if task: + return MessageToDict(proto_utils.ToProto.task(task)) + raise ServerError(error=TaskNotFoundError()) + + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_resubscribe_to_task( + self, + request: Request, + context: ServerCallContext, + ) -> AsyncIterable[str]: + """Handles the 'tasks/resubscribe' REST method. + + Yields response objects as they are produced by the underlying handler's stream. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Yields: + JSON serialized objects containing streaming events + """ + task_id = request.path_params['id'] + async for event in self.request_handler.on_resubscribe_to_task( + TaskIdParams(id=task_id), context + ): + yield MessageToJson(proto_utils.ToProto.stream_response(event)) + + async def get_push_notification( + self, + request: Request, + context: ServerCallContext, + ) -> dict[str, Any]: + """Handles the 'tasks/pushNotificationConfig/get' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `dict` containing the config + """ + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = GetTaskPushNotificationConfigParams( + id=task_id, push_notification_config_id=push_id + ) + config = ( + await self.request_handler.on_get_task_push_notification_config( + params, context + ) + ) + return MessageToDict( + proto_utils.ToProto.task_push_notification_config(config) + ) + + @validate( + lambda self: self.agent_card.capabilities.push_notifications, + 'Push notifications are not supported by the agent', + ) + async def set_push_notification( + self, + request: Request, + context: ServerCallContext, + ) -> dict[str, Any]: + """Handles the 'tasks/pushNotificationConfig/set' REST method. + + Requires the agent to support push notifications. + + Args: + request: The incoming `TaskPushNotificationConfig` object. + context: Context provided by the server. + + Returns: + A `dict` containing the config object. + + Raises: + ServerError: If push notifications are not supported by the agent + (due to the `@validate` decorator), A2AError if processing error is + found. + """ + task_id = request.path_params['id'] + body = await request.body() + params = a2a_pb2.CreateTaskPushNotificationConfigRequest() + Parse(body, params) + a2a_request = ( + proto_utils.FromProto.task_push_notification_config_request( + params, + ) + ) + a2a_request.task_id = task_id + config = ( + await self.request_handler.on_set_task_push_notification_config( + a2a_request, context + ) + ) + return MessageToDict( + proto_utils.ToProto.task_push_notification_config(config) + ) + + async def on_get_task( + self, + request: Request, + context: ServerCallContext, + ) -> dict[str, Any]: + """Handles the 'v1/tasks/{id}' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `Task` object containing the Task. + """ + task_id = request.path_params['id'] + history_length_str = request.query_params.get('historyLength') + history_length = int(history_length_str) if history_length_str else None + params = TaskQueryParams(id=task_id, history_length=history_length) + task = await self.request_handler.on_get_task(params, context) + if task: + return MessageToDict(proto_utils.ToProto.task(task)) + raise ServerError(error=TaskNotFoundError()) + + async def list_push_notifications( + self, + request: Request, + context: ServerCallContext, + ) -> dict[str, Any]: + """Handles the 'tasks/pushNotificationConfig/list' REST method. + + This method is currently not implemented. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A list of `dict` representing the `TaskPushNotificationConfig` objects. + + Raises: + NotImplementedError: This method is not yet implemented. + """ + raise NotImplementedError('list notifications not implemented') + + async def list_tasks( + self, + request: Request, + context: ServerCallContext, + ) -> dict[str, Any]: + """Handles the 'tasks/list' REST method. + + This method is currently not implemented. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A list of dict representing the`Task` objects. + + Raises: + NotImplementedError: This method is not yet implemented. + """ + raise NotImplementedError('list tasks not implemented') diff --git a/src/a2a/utils/error_handlers.py b/src/a2a/utils/error_handlers.py new file mode 100644 index 00000000..22527aa7 --- /dev/null +++ b/src/a2a/utils/error_handlers.py @@ -0,0 +1,118 @@ +import functools +import logging + +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any + +from starlette.responses import JSONResponse, Response + +from a2a._base import A2ABaseModel +from a2a.types import ( + AuthenticatedExtendedCardNotConfiguredError, + ContentTypeNotSupportedError, + InternalError, + InvalidAgentResponseError, + InvalidParamsError, + InvalidRequestError, + JSONParseError, + MethodNotFoundError, + PushNotificationNotSupportedError, + TaskNotCancelableError, + TaskNotFoundError, + UnsupportedOperationError, +) +from a2a.utils.errors import ServerError + + +logger = logging.getLogger(__name__) + +A2AErrorToHttpStatus: dict[type[A2ABaseModel], int] = { + JSONParseError: 400, + InvalidRequestError: 400, + MethodNotFoundError: 404, + InvalidParamsError: 422, + InternalError: 500, + TaskNotFoundError: 404, + TaskNotCancelableError: 409, + PushNotificationNotSupportedError: 501, + UnsupportedOperationError: 501, + ContentTypeNotSupportedError: 415, + InvalidAgentResponseError: 502, + AuthenticatedExtendedCardNotConfiguredError: 404, +} + + +def rest_error_handler( + func: Callable[..., Awaitable[Response]], +) -> Callable[..., Awaitable[Response]]: + """Decorator to catch ServerError and map it to an appropriate JSONResponse.""" + + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Response: + try: + return await func(*args, **kwargs) + except ServerError as e: + error = e.error or InternalError( + message='Internal error due to unknown reason' + ) + http_code = A2AErrorToHttpStatus.get(type(error), 500) + + log_level = ( + logging.ERROR + if isinstance(error, InternalError) + else logging.WARNING + ) + logger.log( + log_level, + 'Request error: ' + f"Code={error.code}, Message='{error.message}'" + f'{", Data=" + str(error.data) if error.data else ""}', + ) + return JSONResponse( + content={'message': error.message}, status_code=http_code + ) + except Exception as e: + logger.log(logging.ERROR, f'Unknown error occurred {e}') + return JSONResponse( + content={'message': 'unknown exception'}, status_code=500 + ) + + return wrapper + + +def rest_stream_error_handler( + func: Callable[..., Coroutine[Any, Any, Any]], +) -> Callable[..., Coroutine[Any, Any, Any]]: + """Decorator to catch ServerError for a straming method,log it and then rethrow it to be handled by framework.""" + + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return await func(*args, **kwargs) + except ServerError as e: + error = e.error or InternalError( + message='Internal error due to unknown reason' + ) + + log_level = ( + logging.ERROR + if isinstance(error, InternalError) + else logging.WARNING + ) + logger.log( + log_level, + 'Request error: ' + f"Code={error.code}, Message='{error.message}'" + f'{", Data=" + str(error.data) if error.data else ""}', + ) + # Since the stream has started, we can't return a JSONResponse. + # Instead, we runt the error handling logic (provides logging) + # and reraise the error and let server framework manage + raise e + except Exception as e: + # Since the stream has started, we can't return a JSONResponse. + # Instead, we runt the error handling logic (provides logging) + # and reraise the error and let server framework manage + raise e + + return wrapper diff --git a/src/a2a/utils/errors.py b/src/a2a/utils/errors.py index 2964172d..f850857a 100644 --- a/src/a2a/utils/errors.py +++ b/src/a2a/utils/errors.py @@ -1,6 +1,7 @@ """Custom exceptions for A2A server-side errors.""" from a2a.types import ( + AuthenticatedExtendedCardNotConfiguredError, ContentTypeNotSupportedError, InternalError, InvalidAgentResponseError, @@ -57,6 +58,7 @@ def __init__( | UnsupportedOperationError | ContentTypeNotSupportedError | InvalidAgentResponseError + | AuthenticatedExtendedCardNotConfiguredError | None ), ): diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 091268ba..0760690b 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -1,6 +1,7 @@ """General utility functions for the A2A Python SDK.""" import functools +import inspect import logging from collections.abc import Callable @@ -135,7 +136,22 @@ def validate( """ def decorator(function: Callable) -> Callable: - def wrapper(self: Any, *args, **kwargs) -> Any: + if inspect.iscoroutinefunction(function): + + @functools.wraps(function) + async def async_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logger.error(f'Unsupported Operation: {final_message}') + raise ServerError( + UnsupportedOperationError(message=final_message) + ) + return await function(self, *args, **kwargs) + + return async_wrapper + + @functools.wraps(function) + def sync_wrapper(self: Any, *args, **kwargs) -> Any: if not expression(self): final_message = error_message or str(expression) logger.error(f'Unsupported Operation: {final_message}') @@ -144,7 +160,7 @@ def wrapper(self: Any, *args, **kwargs) -> Any: ) return function(self, *args, **kwargs) - return wrapper + return sync_wrapper return decorator diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index f4d9e70e..e67f0dfa 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -2,6 +2,7 @@ """Utils for converting between proto and Python types.""" import json +import logging import re from typing import Any @@ -13,9 +14,14 @@ from a2a.utils.errors import ServerError +logger = logging.getLogger(__name__) + + # Regexp patterns for matching -_TASK_NAME_MATCH = r'tasks/(\w+)' -_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotificationConfigs/(\w+)' +_TASK_NAME_MATCH = r'tasks/([\w-]+)' +_TASK_PUSH_CONFIG_NAME_MATCH = ( + r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)' +) class ToProto: @@ -259,7 +265,7 @@ def task_push_notification_config( cls, config: types.TaskPushNotificationConfig ) -> a2a_pb2.TaskPushNotificationConfig: return a2a_pb2.TaskPushNotificationConfig( - name=f'tasks/{config.task_id}/pushNotificationConfigs/{config.task_id}', + name=f'tasks/{config.task_id}/pushNotificationConfigs/{config.push_notification_config.id}', push_notification_config=cls.push_notification_config( config.push_notification_config, ), @@ -286,6 +292,23 @@ def agent_card( supports_authenticated_extended_card=bool( card.supports_authenticated_extended_card ), + preferred_transport=card.preferred_transport, + protocol_version=card.protocol_version, + additional_interfaces=[ + cls.agent_interface(x) for x in card.additional_interfaces + ] + if card.additional_interfaces + else None, + ) + + @classmethod + def agent_interface( + cls, + interface: types.AgentInterface, + ) -> a2a_pb2.AgentInterface: + return a2a_pb2.AgentInterface( + transport=interface.transport, + url=interface.url, ) @classmethod @@ -485,6 +508,14 @@ def file( return types.FileWithUri(uri=file.file_with_uri) return types.FileWithBytes(bytes=file.file_with_bytes.decode('utf-8')) + @classmethod + def task_or_message( + cls, event: a2a_pb2.SendMessageResponse + ) -> types.Task | types.Message: + if event.HasField('msg'): + return cls.message(event.msg) + return cls.task(event.task) + @classmethod def task(cls, task: a2a_pb2.Task) -> types.Task: return types.Task( @@ -632,7 +663,7 @@ def task_id_params( return types.TaskIdParams(id=m.group(1)) @classmethod - def task_push_notification_config( + def task_push_notification_config_request( cls, request: a2a_pb2.CreateTaskPushNotificationConfigRequest, ) -> types.TaskPushNotificationConfig: @@ -650,6 +681,25 @@ def task_push_notification_config( task_id=m.group(1), ) + @classmethod + def task_push_notification_config( + cls, + config: a2a_pb2.TaskPushNotificationConfig, + ) -> types.TaskPushNotificationConfig: + m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, config.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'Bad TaskPushNotificationConfig resource name {config.name}' + ) + ) + return types.TaskPushNotificationConfig( + push_notification_config=cls.push_notification_config( + config.push_notification_config, + ), + task_id=m.group(1), + ) + @classmethod def agent_card( cls, @@ -669,6 +719,23 @@ def agent_card( url=card.url, version=card.version, supports_authenticated_extended_card=card.supports_authenticated_extended_card, + preferred_transport=card.preferred_transport, + protocol_version=card.protocol_version, + additional_interfaces=[ + cls.agent_interface(x) for x in card.additional_interfaces + ] + if card.additional_interfaces + else None, + ) + + @classmethod + def agent_interface( + cls, + interface: a2a_pb2.AgentInterface, + ) -> types.AgentInterface: + return types.AgentInterface( + transport=interface.transport, + url=interface.url, ) @classmethod @@ -799,6 +866,24 @@ def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: ), ) + @classmethod + def stream_response( + cls, + response: a2a_pb2.StreamResponse, + ) -> ( + types.Message + | types.Task + | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent + ): + if response.HasField('msg'): + return cls.message(response.msg) + if response.HasField('task'): + return cls.task(response.task) + if response.HasField('status_update'): + return cls.task_status_update_event(response.status_update) + return cls.task_artifact_update_event(response.artifact_update) + @classmethod def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: return types.AgentSkill( diff --git a/src/a2a/utils/telemetry.py b/src/a2a/utils/telemetry.py index f8908183..f911fd6b 100644 --- a/src/a2a/utils/telemetry.py +++ b/src/a2a/utils/telemetry.py @@ -83,7 +83,7 @@ def internal_method(self): class _NoOp: """A no-op object that absorbs all tracing calls when OpenTelemetry is not installed.""" - def __call__(self, *args: Any, **kwargs: Any) -> '_NoOp': + def __call__(self, *args: Any, **kwargs: Any) -> Any: return self def __enter__(self) -> '_NoOp': @@ -92,12 +92,12 @@ def __enter__(self) -> '_NoOp': def __exit__(self, *args: object, **kwargs: Any) -> None: pass - def __getattr__(self, name: str) -> '_NoOp': + def __getattr__(self, name: str) -> Any: return self - trace = _NoOp() - _SpanKind = _NoOp() - StatusCode = _NoOp() + trace = _NoOp() # type: ignore + _SpanKind = _NoOp() # type: ignore + StatusCode = _NoOp() # type: ignore SpanKind = _SpanKind __all__ = ['SpanKind'] diff --git a/tests/client/test_auth_middleware.py b/tests/client/test_auth_middleware.py index ec89d1e2..4f53ca3f 100644 --- a/tests/client/test_auth_middleware.py +++ b/tests/client/test_auth_middleware.py @@ -1,3 +1,5 @@ +import json + from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -6,8 +8,15 @@ import pytest import respx -from a2a.client import A2AClient, ClientCallContext, ClientCallInterceptor -from a2a.client.auth import AuthInterceptor, InMemoryContextCredentialStore +from a2a.client import ( + AuthInterceptor, + Client, + ClientCallContext, + ClientCallInterceptor, + ClientConfig, + ClientFactory, + InMemoryContextCredentialStore, +) from a2a.types import ( APIKeySecurityScheme, AgentCapabilities, @@ -16,14 +25,13 @@ HTTPAuthSecurityScheme, In, Message, - MessageSendParams, OAuth2SecurityScheme, OAuthFlows, OpenIdConnectSecurityScheme, Role, SecurityScheme, - SendMessageRequest, SendMessageSuccessResponse, + TransportProtocol, ) @@ -48,10 +56,11 @@ async def intercept( return request_payload, http_kwargs -def build_success_response() -> dict: - """Creates a valid JSON-RPC success response as dict.""" - return SendMessageSuccessResponse( - id='1', +def build_success_response(request: httpx.Request) -> httpx.Response: + """Creates a valid JSON-RPC success response based on the request.""" + request_payload = json.loads(request.content) + response_payload = SendMessageSuccessResponse( + id=request_payload['id'], jsonrpc='2.0', result=Message( kind='message', @@ -60,41 +69,33 @@ def build_success_response() -> dict: parts=[], ), ).model_dump(mode='json') + return httpx.Response(200, json=response_payload) -def build_send_message_request() -> SendMessageRequest: - """Builds a minimal SendMessageRequest.""" - return SendMessageRequest( - id='1', - params=MessageSendParams( - message=Message( - message_id='msg1', - role=Role.user, - parts=[], - ) - ), +def build_message() -> Message: + """Builds a minimal Message.""" + return Message( + message_id='msg1', + role=Role.user, + parts=[], ) async def send_message( - client: A2AClient, + client: Client, url: str, session_id: str | None = None, ) -> httpx.Request: """Mocks the response and sends a message using the client.""" - respx.post(url).mock( - return_value=httpx.Response( - 200, - json=build_success_response(), - ) - ) + respx.post(url).mock(side_effect=build_success_response) context = ClientCallContext( state={'sessionId': session_id} if session_id else {} ) - await client.send_message( - request=build_send_message_request(), + async for _ in client.send_message( + request=build_message(), context=context, - ) + ): + pass return respx.calls.last.request @@ -169,11 +170,26 @@ async def test_client_with_simple_interceptor(): """ url = 'http://agent.com/rpc' interceptor = HeaderInterceptor('X-Test-Header', 'Test-Value-123') + card = AgentCard( + url=url, + name='testbot', + description='test bot', + version='1.0', + default_input_modes=[], + default_output_modes=[], + skills=[], + capabilities=AgentCapabilities(), + preferred_transport=TransportProtocol.jsonrpc, + ) async with httpx.AsyncClient() as http_client: - client = A2AClient( - httpx_client=http_client, url=url, interceptors=[interceptor] + config = ClientConfig( + httpx_client=http_client, + supported_transports=[TransportProtocol.jsonrpc], ) + factory = ClientFactory(config) + client = factory.create(card, interceptors=[interceptor]) + request = await send_message(client, url) assert request.headers['x-test-header'] == 'Test-Value-123' @@ -292,14 +308,17 @@ async def test_auth_interceptor_variants(test_case, store): root=test_case.security_scheme ) }, + preferred_transport=TransportProtocol.jsonrpc, ) async with httpx.AsyncClient() as http_client: - client = A2AClient( + config = ClientConfig( httpx_client=http_client, - agent_card=agent_card, - interceptors=[auth_interceptor], + supported_transports=[TransportProtocol.jsonrpc], ) + factory = ClientFactory(config) + client = factory.create(agent_card, interceptors=[auth_interceptor]) + request = await send_message( client, test_case.url, test_case.session_id ) diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py new file mode 100644 index 00000000..7b1aacec --- /dev/null +++ b/tests/client/test_base_client.py @@ -0,0 +1,118 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from a2a.client.base_client import BaseClient +from a2a.client.client import ClientConfig +from a2a.client.transports.base import ClientTransport +from a2a.types import ( + AgentCapabilities, + AgentCard, + Message, + Part, + Role, + Task, + TaskState, + TaskStatus, + TextPart, +) + + +@pytest.fixture +def mock_transport(): + transport = AsyncMock(spec=ClientTransport) + return transport + + +@pytest.fixture +def sample_agent_card(): + return AgentCard( + name='Test Agent', + description='An agent for testing', + url='http://test.com', + version='1.0', + capabilities=AgentCapabilities(streaming=True), + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + skills=[], + ) + + +@pytest.fixture +def sample_message(): + return Message( + role=Role.user, + message_id='msg-1', + parts=[Part(root=TextPart(text='Hello'))], + ) + + +@pytest.fixture +def base_client(sample_agent_card, mock_transport): + config = ClientConfig(streaming=True) + return BaseClient( + card=sample_agent_card, + config=config, + transport=mock_transport, + consumers=[], + middleware=[], + ) + + +@pytest.mark.asyncio +async def test_send_message_streaming( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + async def create_stream(*args, **kwargs): + yield Task( + id='task-123', + context_id='ctx-456', + status=TaskStatus(state=TaskState.completed), + ) + + mock_transport.send_message_streaming.return_value = create_stream() + + events = [event async for event in base_client.send_message(sample_message)] + + mock_transport.send_message_streaming.assert_called_once() + assert not mock_transport.send_message.called + assert len(events) == 1 + assert events[0][0].id == 'task-123' + + +@pytest.mark.asyncio +async def test_send_message_non_streaming( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + base_client._config.streaming = False + mock_transport.send_message.return_value = Task( + id='task-456', + context_id='ctx-789', + status=TaskStatus(state=TaskState.completed), + ) + + events = [event async for event in base_client.send_message(sample_message)] + + mock_transport.send_message.assert_called_once() + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + assert events[0][0].id == 'task-456' + + +@pytest.mark.asyncio +async def test_send_message_non_streaming_agent_capability_false( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + base_client._card.capabilities.streaming = False + mock_transport.send_message.return_value = Task( + id='task-789', + context_id='ctx-101', + status=TaskStatus(state=TaskState.completed), + ) + + events = [event async for event in base_client.send_message(sample_message)] + + mock_transport.send_message.assert_called_once() + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + assert events[0][0].id == 'task-789' diff --git a/tests/client/test_client.py b/tests/client/test_client.py deleted file mode 100644 index 547fbe10..00000000 --- a/tests/client/test_client.py +++ /dev/null @@ -1,1296 +0,0 @@ -import json - -from collections.abc import AsyncGenerator -from typing import Any -from unittest.mock import ANY, AsyncMock, MagicMock, patch - -import httpx -import pytest - -from httpx_sse import EventSource, SSEError, ServerSentEvent - -from a2a.client import ( - A2ACardResolver, - A2AClient, - A2AClientHTTPError, - A2AClientJSONError, - A2AClientTimeoutError, - create_text_message_object, -) -from a2a.types import ( - A2ARequest, - AgentCapabilities, - AgentCard, - AgentSkill, - CancelTaskRequest, - CancelTaskResponse, - CancelTaskSuccessResponse, - GetTaskPushNotificationConfigRequest, - GetTaskPushNotificationConfigResponse, - GetTaskPushNotificationConfigSuccessResponse, - GetTaskRequest, - GetTaskResponse, - InvalidParamsError, - JSONRPCErrorResponse, - MessageSendParams, - PushNotificationConfig, - Role, - SendMessageRequest, - SendMessageResponse, - SendMessageSuccessResponse, - SendStreamingMessageRequest, - SendStreamingMessageResponse, - SetTaskPushNotificationConfigRequest, - SetTaskPushNotificationConfigResponse, - SetTaskPushNotificationConfigSuccessResponse, - TaskIdParams, - TaskNotCancelableError, - TaskPushNotificationConfig, - TaskQueryParams, -) -from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH - - -AGENT_CARD = AgentCard( - name='Hello World Agent', - description='Just a hello world agent', - url='http://localhost:9999/', - version='1.0.0', - default_input_modes=['text'], - default_output_modes=['text'], - capabilities=AgentCapabilities(), - skills=[ - AgentSkill( - id='hello_world', - name='Returns hello world', - description='just returns hello world', - tags=['hello world'], - examples=['hi', 'hello world'], - ) - ], -) - -AGENT_CARD_EXTENDED = AGENT_CARD.model_copy( - update={ - 'name': 'Hello World Agent - Extended Edition', - 'skills': [ - *AGENT_CARD.skills, - AgentSkill( - id='extended_skill', - name='Super Greet', - description='A more enthusiastic greeting.', - tags=['extended'], - examples=['super hi'], - ), - ], - 'version': '1.0.1', - } -) - -AGENT_CARD_SUPPORTS_EXTENDED = AGENT_CARD.model_copy( - update={'supports_authenticated_extended_card': True} -) -AGENT_CARD_NO_URL_SUPPORTS_EXTENDED = AGENT_CARD_SUPPORTS_EXTENDED.model_copy( - update={'url': ''} -) - -MINIMAL_TASK: dict[str, Any] = { - 'id': 'task-abc', - 'contextId': 'session-xyz', - 'status': {'state': 'working'}, - 'kind': 'task', -} - -MINIMAL_CANCELLED_TASK: dict[str, Any] = { - 'id': 'task-abc', - 'contextId': 'session-xyz', - 'status': {'state': 'canceled'}, - 'kind': 'task', -} - - -@pytest.fixture -def mock_httpx_client() -> AsyncMock: - return AsyncMock(spec=httpx.AsyncClient) - - -@pytest.fixture -def mock_agent_card() -> MagicMock: - return MagicMock(spec=AgentCard, url='http://agent.example.com/api') - - -async def async_iterable_from_list( - items: list[ServerSentEvent], -) -> AsyncGenerator[ServerSentEvent]: - """Helper to create an async iterable from a list.""" - for item in items: - yield item - - -class TestA2ACardResolver: - BASE_URL = 'http://example.com' - AGENT_CARD_PATH = AGENT_CARD_WELL_KNOWN_PATH - FULL_AGENT_CARD_URL = f'{BASE_URL}{AGENT_CARD_PATH}' - EXTENDED_AGENT_CARD_PATH = ( - '/agent/authenticatedExtendedCard' # Default path - ) - - @pytest.mark.asyncio - async def test_init_parameters_stored_correctly( - self, mock_httpx_client: AsyncMock - ): - base_url = 'http://example.com' - custom_path = '/custom/agent-card.json' - resolver = A2ACardResolver( - httpx_client=mock_httpx_client, - base_url=base_url, - agent_card_path=custom_path, - ) - assert resolver.base_url == base_url - assert resolver.agent_card_path == custom_path.lstrip('/') - assert resolver.httpx_client == mock_httpx_client - - # Test default agent_card_path - resolver_default_path = A2ACardResolver( - httpx_client=mock_httpx_client, - base_url=base_url, - ) - assert ( - '/' + resolver_default_path.agent_card_path - == AGENT_CARD_WELL_KNOWN_PATH - ) - - @pytest.mark.asyncio - async def test_init_strips_slashes(self, mock_httpx_client: AsyncMock): - resolver = A2ACardResolver( - httpx_client=mock_httpx_client, - base_url='http://example.com/', # With trailing slash - agent_card_path='/.well-known/agent-card.json/', # With leading/trailing slash - ) - assert ( - resolver.base_url == 'http://example.com' - ) # Trailing slash stripped - # constructor lstrips agent_card_path, but keeps trailing if provided - assert resolver.agent_card_path == '.well-known/agent-card.json/' - - @pytest.mark.asyncio - async def test_get_agent_card_success_public_only( - self, mock_httpx_client: AsyncMock - ): - mock_response = AsyncMock(spec=httpx.Response) - mock_response.status_code = 200 - mock_response.json.return_value = AGENT_CARD.model_dump(mode='json') - mock_httpx_client.get.return_value = mock_response - - resolver = A2ACardResolver( - httpx_client=mock_httpx_client, - base_url=self.BASE_URL, - agent_card_path=self.AGENT_CARD_PATH, - ) - agent_card = await resolver.get_agent_card(http_kwargs={'timeout': 10}) - - mock_httpx_client.get.assert_called_once_with( - self.FULL_AGENT_CARD_URL, timeout=10 - ) - mock_response.raise_for_status.assert_called_once() - assert isinstance(agent_card, AgentCard) - assert agent_card == AGENT_CARD - # Ensure only one call was made (for the public card) - assert mock_httpx_client.get.call_count == 1 - - @pytest.mark.asyncio - async def test_get_agent_card_success_with_specified_path_for_extended_card( - self, mock_httpx_client: AsyncMock - ): - extended_card_response = AsyncMock(spec=httpx.Response) - extended_card_response.status_code = 200 - extended_card_response.json.return_value = ( - AGENT_CARD_EXTENDED.model_dump(mode='json') - ) - - # Mock the single call for the extended card - mock_httpx_client.get.return_value = extended_card_response - - resolver = A2ACardResolver( - httpx_client=mock_httpx_client, - base_url=self.BASE_URL, - agent_card_path=self.AGENT_CARD_PATH, - ) - - # Fetch the extended card by providing its relative path and example auth - auth_kwargs = {'headers': {'Authorization': 'Bearer test token'}} - agent_card_result = await resolver.get_agent_card( - relative_card_path=self.EXTENDED_AGENT_CARD_PATH, - http_kwargs=auth_kwargs, - ) - - expected_extended_url = ( - f'{self.BASE_URL}/{self.EXTENDED_AGENT_CARD_PATH.lstrip("/")}' - ) - mock_httpx_client.get.assert_called_once_with( - expected_extended_url, **auth_kwargs - ) - extended_card_response.raise_for_status.assert_called_once() - - assert isinstance(agent_card_result, AgentCard) - assert ( - agent_card_result == AGENT_CARD_EXTENDED - ) # Should return the extended card - - @pytest.mark.asyncio - async def test_get_agent_card_validation_error( - self, mock_httpx_client: AsyncMock - ): - mock_response = AsyncMock(spec=httpx.Response) - mock_response.status_code = 200 - # Data that will cause a Pydantic ValidationError - mock_response.json.return_value = { - 'invalid_field': 'value', - 'name': 'Test Agent', - } - mock_httpx_client.get.return_value = mock_response - - resolver = A2ACardResolver( - httpx_client=mock_httpx_client, base_url=self.BASE_URL - ) - # The call that is expected to raise an error should be within pytest.raises - with pytest.raises(A2AClientJSONError) as exc_info: - await resolver.get_agent_card() # Fetches from default path - - assert ( - f'Failed to validate agent card structure from {self.FULL_AGENT_CARD_URL}' - in str(exc_info.value) - ) - assert 'invalid_field' in str( - exc_info.value - ) # Check if Pydantic error details are present - assert ( - mock_httpx_client.get.call_count == 1 - ) # Should only be called once - - @pytest.mark.asyncio - async def test_get_agent_card_http_status_error( - self, mock_httpx_client: AsyncMock - ): - mock_response = MagicMock( - spec=httpx.Response - ) # Use MagicMock for response attribute - mock_response.status_code = 404 - mock_response.text = 'Not Found' - - http_status_error = httpx.HTTPStatusError( - 'Not Found', request=MagicMock(), response=mock_response - ) - mock_httpx_client.get.side_effect = http_status_error - - resolver = A2ACardResolver( - httpx_client=mock_httpx_client, - base_url=self.BASE_URL, - agent_card_path=self.AGENT_CARD_PATH, - ) - - with pytest.raises(A2AClientHTTPError) as exc_info: - await resolver.get_agent_card() - - assert exc_info.value.status_code == 404 - assert ( - f'Failed to fetch agent card from {self.FULL_AGENT_CARD_URL}' - in str(exc_info.value) - ) - assert 'Not Found' in str(exc_info.value) - mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL) - - @pytest.mark.asyncio - async def test_get_agent_card_json_decode_error( - self, mock_httpx_client: AsyncMock - ): - mock_response = AsyncMock(spec=httpx.Response) - mock_response.status_code = 200 - # Define json_error before using it - json_error = json.JSONDecodeError('Expecting value', 'doc', 0) - mock_response.json.side_effect = json_error - mock_httpx_client.get.return_value = mock_response - - resolver = A2ACardResolver( - httpx_client=mock_httpx_client, - base_url=self.BASE_URL, - agent_card_path=self.AGENT_CARD_PATH, - ) - - with pytest.raises(A2AClientJSONError) as exc_info: - await resolver.get_agent_card() - - # Assertions using exc_info must be after the with block - assert ( - f'Failed to parse JSON for agent card from {self.FULL_AGENT_CARD_URL}' - in str(exc_info.value) - ) - assert 'Expecting value' in str(exc_info.value) - mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL) - - @pytest.mark.asyncio - async def test_get_agent_card_request_error( - self, mock_httpx_client: AsyncMock - ): - request_error = httpx.RequestError('Network issue', request=MagicMock()) - mock_httpx_client.get.side_effect = request_error - - resolver = A2ACardResolver( - httpx_client=mock_httpx_client, - base_url=self.BASE_URL, - agent_card_path=self.AGENT_CARD_PATH, - ) - - with pytest.raises(A2AClientHTTPError) as exc_info: - await resolver.get_agent_card() - - assert exc_info.value.status_code == 503 - assert ( - f'Network communication error fetching agent card from {self.FULL_AGENT_CARD_URL}' - in str(exc_info.value) - ) - assert 'Network issue' in str(exc_info.value) - mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL) - - -class TestA2AClient: - AGENT_URL = 'http://agent.example.com/api' - - def test_init_with_agent_card( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - assert client.url == mock_agent_card.url - assert client.httpx_client == mock_httpx_client - - def test_init_with_url(self, mock_httpx_client: AsyncMock): - client = A2AClient(httpx_client=mock_httpx_client, url=self.AGENT_URL) - assert client.url == self.AGENT_URL - assert client.httpx_client == mock_httpx_client - - def test_init_with_agent_card_and_url_prioritizes_agent_card( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, - agent_card=mock_agent_card, - url='http://otherurl.com', - ) - assert ( - client.url == mock_agent_card.url - ) # Agent card URL should be used - - def test_init_raises_value_error_if_no_card_or_url( - self, mock_httpx_client: AsyncMock - ): - with pytest.raises(ValueError) as exc_info: - A2AClient(httpx_client=mock_httpx_client) - assert 'Must provide either agent_card or url' in str(exc_info.value) - - @pytest.mark.asyncio - async def test_get_client_from_agent_card_url_success( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - base_url = 'http://example.com' - agent_card_path = '/.well-known/custom-agent.json' - resolver_kwargs = {'timeout': 30} - - mock_resolver_instance = AsyncMock(spec=A2ACardResolver) - mock_resolver_instance.get_agent_card.return_value = mock_agent_card - - with patch( - 'a2a.client.client.A2ACardResolver', - return_value=mock_resolver_instance, - ) as mock_resolver_class: - client = await A2AClient.get_client_from_agent_card_url( - httpx_client=mock_httpx_client, - base_url=base_url, - agent_card_path=agent_card_path, - http_kwargs=resolver_kwargs, - ) - - mock_resolver_class.assert_called_once_with( - mock_httpx_client, - base_url=base_url, - agent_card_path=agent_card_path, - ) - mock_resolver_instance.get_agent_card.assert_called_once_with( - http_kwargs=resolver_kwargs, - # relative_card_path=None is implied by not passing it - ) - assert isinstance(client, A2AClient) - assert client.url == mock_agent_card.url - assert client.httpx_client == mock_httpx_client - - @pytest.mark.asyncio - async def test_get_client_from_agent_card_url_resolver_error( - self, mock_httpx_client: AsyncMock - ): - error_to_raise = A2AClientHTTPError(404, 'Agent card not found') - with patch( - 'a2a.client.client.A2ACardResolver.get_agent_card', - new_callable=AsyncMock, - side_effect=error_to_raise, - ): - with pytest.raises(A2AClientHTTPError) as exc_info: - await A2AClient.get_client_from_agent_card_url( - httpx_client=mock_httpx_client, - base_url='http://example.com', - ) - assert exc_info.value == error_to_raise - - @pytest.mark.asyncio - async def test_send_message_success_use_request( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - - params = MessageSendParams( - message=create_text_message_object(content='Hello') - ) - - request = SendMessageRequest(id=123, params=params) - - success_response = create_text_message_object( - role=Role.agent, content='Hi there!' - ).model_dump(exclude_none=True) - - rpc_response: dict[str, Any] = { - 'id': 123, - 'jsonrpc': '2.0', - 'result': success_response, - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response - response = await client.send_message( - request=request, http_kwargs={'timeout': 10} - ) - - assert mock_send_req.call_count == 1 - called_args, called_kwargs = mock_send_req.call_args - assert not called_kwargs # no kwargs to _send_request - assert len(called_args) == 2 - json_rpc_request: dict[str, Any] = called_args[0] - assert isinstance(json_rpc_request['id'], int) - http_kwargs: dict[str, Any] = called_args[1] - assert http_kwargs['timeout'] == 10 - - a2a_request_arg = A2ARequest.model_validate(json_rpc_request) - assert isinstance(a2a_request_arg.root, SendMessageRequest) - assert isinstance(a2a_request_arg.root.params, MessageSendParams) - - assert a2a_request_arg.root.params.model_dump( - exclude_none=True - ) == params.model_dump(exclude_none=True) - - assert isinstance(response, SendMessageResponse) - assert isinstance(response.root, SendMessageSuccessResponse) - assert ( - response.root.result.model_dump(exclude_none=True) - == success_response - ) - - @pytest.mark.asyncio - async def test_send_message_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - - params = MessageSendParams( - message=create_text_message_object(content='Hello') - ) - - request = SendMessageRequest(id=123, params=params) - - error_response = InvalidParamsError() - - rpc_response: dict[str, Any] = { - 'id': 123, - 'jsonrpc': '2.0', - 'error': error_response.model_dump(exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response - response = await client.send_message(request=request) - - assert isinstance(response, SendMessageResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - exclude_none=True - ) == InvalidParamsError().model_dump(exclude_none=True) - - @pytest.mark.asyncio - @patch('a2a.client.client.aconnect_sse') - async def test_send_message_streaming_success_request( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params = MessageSendParams( - message=create_text_message_object(content='Hello stream') - ) - - request = SendStreamingMessageRequest(id=123, params=params) - - mock_stream_response_1_dict: dict[str, Any] = { - 'id': 'stream_id_123', - 'jsonrpc': '2.0', - 'result': create_text_message_object( - content='First part ', role=Role.agent - ).model_dump(mode='json', exclude_none=True), - } - mock_stream_response_2_dict: dict[str, Any] = { - 'id': 'stream_id_123', - 'jsonrpc': '2.0', - 'result': create_text_message_object( - content='second part ', role=Role.agent - ).model_dump(mode='json', exclude_none=True), - } - - sse_event_1 = ServerSentEvent( - data=json.dumps(mock_stream_response_1_dict) - ) - sse_event_2 = ServerSentEvent( - data=json.dumps(mock_stream_response_2_dict) - ) - - mock_event_source = AsyncMock(spec=EventSource) - with patch.object(mock_event_source, 'aiter_sse') as mock_aiter_sse: - mock_aiter_sse.return_value = async_iterable_from_list( - [sse_event_1, sse_event_2] - ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - results: list[Any] = [] - async for response in client.send_message_streaming( - request=request - ): - results.append(response) - - assert len(results) == 2 - assert isinstance(results[0], SendStreamingMessageResponse) - # Assuming SendStreamingMessageResponse is a RootModel like SendMessageResponse - assert results[0].root.id == 'stream_id_123' - assert ( - results[0].root.result.model_dump( # type: ignore - mode='json', exclude_none=True - ) - == mock_stream_response_1_dict['result'] - ) - - assert isinstance(results[1], SendStreamingMessageResponse) - assert results[1].root.id == 'stream_id_123' - assert ( - results[1].root.result.model_dump( # type: ignore - mode='json', exclude_none=True - ) - == mock_stream_response_2_dict['result'] - ) - - mock_aconnect_sse.assert_called_once() - call_args, call_kwargs = mock_aconnect_sse.call_args - assert call_args[0] == mock_httpx_client - assert call_args[1] == 'POST' - assert call_args[2] == mock_agent_card.url - - sent_json_payload = call_kwargs['json'] - assert sent_json_payload['method'] == 'message/stream' - assert sent_json_payload['params'] == params.model_dump( - mode='json', exclude_none=True - ) - assert ( - call_kwargs['timeout'] is None - ) # Default timeout for streaming - - @pytest.mark.asyncio - @patch('a2a.client.client.aconnect_sse') - async def test_send_message_streaming_http_kwargs_passed( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params = MessageSendParams( - message=create_text_message_object(content='Stream with kwargs') - ) - request = SendStreamingMessageRequest(id='kwarg_req', params=params) - custom_kwargs = { - 'headers': {'X-Custom-Header': 'TestValue'}, - 'timeout': 60, - } - - # Setup mock_aconnect_sse to behave minimally - mock_event_source = AsyncMock(spec=EventSource) - mock_event_source.aiter_sse.return_value = async_iterable_from_list( - [] - ) # No events needed for this test - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - async for _ in client.send_message_streaming( - request=request, http_kwargs=custom_kwargs - ): - pass # We just want to check the call to aconnect_sse - - mock_aconnect_sse.assert_called_once() - _, called_kwargs = mock_aconnect_sse.call_args - assert called_kwargs['headers'] == custom_kwargs['headers'] - assert ( - called_kwargs['timeout'] == custom_kwargs['timeout'] - ) # Ensure custom timeout is used - - @pytest.mark.asyncio - @patch('a2a.client.client.aconnect_sse') - async def test_send_message_streaming_sse_error_handling( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - request = SendStreamingMessageRequest( - id='sse_err_req', - params=MessageSendParams( - message=create_text_message_object(content='SSE error test') - ), - ) - - # Configure the mock aconnect_sse to raise SSEError when aiter_sse is called - mock_event_source = AsyncMock(spec=EventSource) - mock_event_source.aiter_sse.side_effect = SSEError( - 'Simulated SSE protocol error' - ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - with pytest.raises(A2AClientHTTPError) as exc_info: - async for _ in client.send_message_streaming(request=request): - pass - - assert exc_info.value.status_code == 400 # As per client implementation - assert 'Invalid SSE response or protocol error' in str(exc_info.value) - assert 'Simulated SSE protocol error' in str(exc_info.value) - - @pytest.mark.asyncio - @patch('a2a.client.client.aconnect_sse') - async def test_send_message_streaming_json_decode_error_handling( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - request = SendStreamingMessageRequest( - id='json_err_req', - params=MessageSendParams( - message=create_text_message_object(content='JSON error test') - ), - ) - - # Malformed JSON event - malformed_sse_event = ServerSentEvent(data='not valid json') - - mock_event_source = AsyncMock(spec=EventSource) - # json.loads will be called on "not valid json" and raise JSONDecodeError - mock_event_source.aiter_sse.return_value = async_iterable_from_list( - [malformed_sse_event] - ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - with pytest.raises(A2AClientJSONError) as exc_info: - async for _ in client.send_message_streaming(request=request): - pass - - assert 'Expecting value: line 1 column 1 (char 0)' in str( - exc_info.value - ) # Example of JSONDecodeError message - - @pytest.mark.asyncio - @patch('a2a.client.client.aconnect_sse') - async def test_send_message_streaming_httpx_request_error_handling( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - request = SendStreamingMessageRequest( - id='httpx_err_req', - params=MessageSendParams( - message=create_text_message_object(content='httpx error test') - ), - ) - - # Configure aconnect_sse itself to raise httpx.RequestError (e.g., during connection) - # This needs to be raised when aconnect_sse is entered or iterated. - # One way is to make the context manager's __aenter__ raise it, or aiter_sse. - # For simplicity, let's make aiter_sse raise it, as if the error occurs after connection. - mock_event_source = AsyncMock(spec=EventSource) - mock_event_source.aiter_sse.side_effect = httpx.RequestError( - 'Simulated network error', request=MagicMock() - ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - with pytest.raises(A2AClientHTTPError) as exc_info: - async for _ in client.send_message_streaming(request=request): - pass - - assert exc_info.value.status_code == 503 # As per client implementation - assert 'Network communication error' in str(exc_info.value) - assert 'Simulated network error' in str(exc_info.value) - - @pytest.mark.asyncio - async def test_send_request_http_status_error( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - mock_response = MagicMock(spec=httpx.Response) - mock_response.status_code = 404 - mock_response.text = 'Not Found' - http_error = httpx.HTTPStatusError( - 'Not Found', request=MagicMock(), response=mock_response - ) - mock_httpx_client.post.side_effect = http_error - - with pytest.raises(A2AClientHTTPError) as exc_info: - await client._send_request({}, {}) - - assert exc_info.value.status_code == 404 - assert 'Not Found' in str(exc_info.value) - - @pytest.mark.asyncio - async def test_send_request_json_decode_error( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - mock_response = AsyncMock(spec=httpx.Response) - mock_response.status_code = 200 - json_error = json.JSONDecodeError('Expecting value', 'doc', 0) - mock_response.json.side_effect = json_error - mock_httpx_client.post.return_value = mock_response - - with pytest.raises(A2AClientJSONError) as exc_info: - await client._send_request({}, {}) - - assert 'Expecting value' in str(exc_info.value) - - @pytest.mark.asyncio - async def test_send_request_httpx_request_error( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - request_error = httpx.RequestError('Network issue', request=MagicMock()) - mock_httpx_client.post.side_effect = request_error - - with pytest.raises(A2AClientHTTPError) as exc_info: - await client._send_request({}, {}) - - assert exc_info.value.status_code == 503 - assert 'Network communication error' in str(exc_info.value) - assert 'Network issue' in str(exc_info.value) - - @pytest.mark.asyncio - async def test_set_task_callback_success( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - task_id_val = 'task_set_cb_001' - # Correctly create the PushNotificationConfig (inner model) - push_config_payload = PushNotificationConfig( - url='https://callback.example.com/taskupdate' - ) - # Correctly create the TaskPushNotificationConfig (outer model) - params_model = TaskPushNotificationConfig( - task_id=task_id_val, push_notification_config=push_config_payload - ) - - # request.id will be generated by the client method if not provided - request = SetTaskPushNotificationConfigRequest( - id='', params=params_model - ) # Test ID auto-generation - - # The result for a successful set operation is the same config - rpc_response_payload: dict[str, Any] = { - 'id': ANY, # Will be checked against generated ID - 'jsonrpc': '2.0', - 'result': params_model.model_dump(mode='json', exclude_none=True), - } - - with ( - patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req, - patch( - 'a2a.client.client.uuid4', - return_value=MagicMock(hex='testuuid'), - ) as mock_uuid, - ): - # Capture the generated ID for assertion - generated_id = str(mock_uuid.return_value) - rpc_response_payload['id'] = ( - generated_id # Ensure mock response uses the generated ID - ) - mock_send_req.return_value = rpc_response_payload - - response = await client.set_task_callback(request=request) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args - sent_json_payload = called_args[0] - - assert sent_json_payload['id'] == generated_id - assert ( - sent_json_payload['method'] - == 'tasks/pushNotificationConfig/set' - ) - assert sent_json_payload['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, SetTaskPushNotificationConfigResponse) - assert isinstance( - response.root, SetTaskPushNotificationConfigSuccessResponse - ) - assert response.root.id == generated_id - assert response.root.result.model_dump( - mode='json', exclude_none=True - ) == params_model.model_dump(mode='json', exclude_none=True) - - @pytest.mark.asyncio - async def test_set_task_callback_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - req_id = 'set_cb_err_req' - push_config_payload = PushNotificationConfig(url='https://errors.com') - params_model = TaskPushNotificationConfig( - task_id='task_err_cb', push_notification_config=push_config_payload - ) - request = SetTaskPushNotificationConfigRequest( - id=req_id, params=params_model - ) - error_details = InvalidParamsError(message='Invalid callback URL') - - rpc_response_payload: dict[str, Any] = { - 'id': req_id, - 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.set_task_callback(request=request) - - assert isinstance(response, SetTaskPushNotificationConfigResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(mode='json', exclude_none=True) - assert response.root.id == req_id - - @pytest.mark.asyncio - async def test_set_task_callback_http_kwargs_passed( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - push_config_payload = PushNotificationConfig(url='https://kwargs.com') - params_model = TaskPushNotificationConfig( - task_id='task_cb_kwargs', - push_notification_config=push_config_payload, - ) - request = SetTaskPushNotificationConfigRequest( - id='cb_kwargs_req', params=params_model - ) - custom_kwargs = {'headers': {'X-Callback-Token': 'secret'}} - - # Minimal successful response - rpc_response_payload: dict[str, Any] = { - 'id': 'cb_kwargs_req', - 'jsonrpc': '2.0', - 'result': params_model.model_dump(mode='json'), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - await client.set_task_callback( - request=request, http_kwargs=custom_kwargs - ) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args # Correctly unpack args - assert ( - called_args[1] == custom_kwargs - ) # http_kwargs is the second positional arg - - @pytest.mark.asyncio - async def test_get_task_callback_success( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - task_id_val = 'task_get_cb_001' - params_model = TaskIdParams( - id=task_id_val - ) # Params for get is just TaskIdParams - - request = GetTaskPushNotificationConfigRequest( - id='', params=params_model - ) # ID is empty string for auto-generation test - - # Expected result for a successful get operation - push_config_payload = PushNotificationConfig( - url='https://callback.example.com/taskupdate' - ) - expected_callback_config = TaskPushNotificationConfig( - task_id=task_id_val, push_notification_config=push_config_payload - ) - rpc_response_payload: dict[str, Any] = { - 'id': ANY, - 'jsonrpc': '2.0', - 'result': expected_callback_config.model_dump( - mode='json', exclude_none=True - ), - } - - with ( - patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req, - patch( - 'a2a.client.client.uuid4', - return_value=MagicMock(hex='testgetuuid'), - ) as mock_uuid, - ): - generated_id = str(mock_uuid.return_value) - rpc_response_payload['id'] = generated_id - mock_send_req.return_value = rpc_response_payload - - response = await client.get_task_callback(request=request) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args - sent_json_payload = called_args[0] - - assert sent_json_payload['id'] == generated_id - assert ( - sent_json_payload['method'] - == 'tasks/pushNotificationConfig/get' - ) - assert sent_json_payload['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, GetTaskPushNotificationConfigResponse) - assert isinstance( - response.root, GetTaskPushNotificationConfigSuccessResponse - ) - assert response.root.id == generated_id - assert response.root.result.model_dump( - mode='json', exclude_none=True - ) == expected_callback_config.model_dump( - mode='json', exclude_none=True - ) - - @pytest.mark.asyncio - async def test_get_task_callback_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - req_id = 'get_cb_err_req' - params_model = TaskIdParams(id='task_get_err_cb') - request = GetTaskPushNotificationConfigRequest( - id=req_id, params=params_model - ) - error_details = TaskNotCancelableError( - message='Cannot get callback for uncancelable task' - ) # Example error - - rpc_response_payload: dict[str, Any] = { - 'id': req_id, - 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.get_task_callback(request=request) - - assert isinstance(response, GetTaskPushNotificationConfigResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(mode='json', exclude_none=True) - assert response.root.id == req_id - - @pytest.mark.asyncio - async def test_get_task_callback_http_kwargs_passed( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params_model = TaskIdParams(id='task_get_cb_kwargs') - request = GetTaskPushNotificationConfigRequest( - id='get_cb_kwargs_req', params=params_model - ) - custom_kwargs = {'headers': {'X-Tenant-ID': 'tenant-x'}} - - # Correctly create the nested PushNotificationConfig - push_config_payload_for_expected = PushNotificationConfig( - url='https://getkwargs.com' - ) - expected_callback_config = TaskPushNotificationConfig( - task_id='task_get_cb_kwargs', - push_notification_config=push_config_payload_for_expected, - ) - rpc_response_payload: dict[str, Any] = { - 'id': 'get_cb_kwargs_req', - 'jsonrpc': '2.0', - 'result': expected_callback_config.model_dump(mode='json'), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - await client.get_task_callback( - request=request, http_kwargs=custom_kwargs - ) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args # Correctly unpack args - assert ( - called_args[1] == custom_kwargs - ) # http_kwargs is the second positional arg - - @pytest.mark.asyncio - async def test_get_task_success_use_request( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - task_id_val = 'task_for_req_obj' - params_model = TaskQueryParams(id=task_id_val) - request_obj_id = 789 - request = GetTaskRequest(id=request_obj_id, params=params_model) - - rpc_response_payload: dict[str, Any] = { - 'id': request_obj_id, - 'jsonrpc': '2.0', - 'result': MINIMAL_TASK, - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.get_task( - request=request, http_kwargs={'timeout': 20} - ) - - assert mock_send_req.call_count == 1 - called_args, called_kwargs = mock_send_req.call_args - assert len(called_args) == 2 - json_rpc_request_sent: dict[str, Any] = called_args[0] - assert not called_kwargs # no extra kwargs to _send_request - http_kwargs: dict[str, Any] = called_args[1] - assert http_kwargs['timeout'] == 20 - - assert json_rpc_request_sent['method'] == 'tasks/get' - assert json_rpc_request_sent['id'] == request_obj_id - assert json_rpc_request_sent['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, GetTaskResponse) - assert hasattr(response.root, 'result') - assert ( - response.root.result.model_dump(mode='json', exclude_none=True) # type: ignore - == MINIMAL_TASK - ) - assert response.root.id == request_obj_id - - @pytest.mark.asyncio - async def test_get_task_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params_model = TaskQueryParams(id='task_error_case') - request = GetTaskRequest(id='err_req_id', params=params_model) - error_details = InvalidParamsError() - - rpc_response_payload: dict[str, Any] = { - 'id': 'err_req_id', - 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.get_task(request=request) - - assert isinstance(response, GetTaskResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(exclude_none=True) - assert response.root.id == 'err_req_id' - - @pytest.mark.asyncio - async def test_cancel_task_success_use_request( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - task_id_val = MINIMAL_CANCELLED_TASK['id'] - params_model = TaskIdParams(id=task_id_val) - request_obj_id = 'cancel_req_obj_id_001' - request = CancelTaskRequest(id=request_obj_id, params=params_model) - - rpc_response_payload: dict[str, Any] = { - 'id': request_obj_id, - 'jsonrpc': '2.0', - 'result': MINIMAL_CANCELLED_TASK, - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.cancel_task( - request=request, http_kwargs={'timeout': 15} - ) - - assert mock_send_req.call_count == 1 - called_args, called_kwargs = mock_send_req.call_args - assert not called_kwargs # no extra kwargs to _send_request - assert len(called_args) == 2 - json_rpc_request_sent: dict[str, Any] = called_args[0] - http_kwargs: dict[str, Any] = called_args[1] - assert http_kwargs['timeout'] == 15 - - assert json_rpc_request_sent['method'] == 'tasks/cancel' - assert json_rpc_request_sent['id'] == request_obj_id - assert json_rpc_request_sent['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, CancelTaskResponse) - assert isinstance(response.root, CancelTaskSuccessResponse) - assert ( - response.root.result.model_dump(mode='json', exclude_none=True) # type: ignore - == MINIMAL_CANCELLED_TASK - ) - assert response.root.id == request_obj_id - - @pytest.mark.asyncio - async def test_cancel_task_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params_model = TaskIdParams(id='task_cancel_error_case') - request = CancelTaskRequest(id='err_cancel_req', params=params_model) - error_details = TaskNotCancelableError() - - rpc_response_payload: dict[str, Any] = { - 'id': 'err_cancel_req', - 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.cancel_task(request=request) - - assert isinstance(response, CancelTaskResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(exclude_none=True) - assert response.root.id == 'err_cancel_req' - - @pytest.mark.asyncio - async def test_send_message_client_timeout( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - mock_httpx_client.post.side_effect = httpx.ReadTimeout( - 'Request timed out' - ) - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - - params = MessageSendParams( - message=create_text_message_object(content='Hello') - ) - - request = SendMessageRequest(id=123, params=params) - - with pytest.raises(A2AClientTimeoutError) as exc_info: - await client.send_message(request=request) - - assert 'Request timed out' in str(exc_info.value) diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py new file mode 100644 index 00000000..d615bbff --- /dev/null +++ b/tests/client/test_client_factory.py @@ -0,0 +1,105 @@ +"""Tests for the ClientFactory.""" + +import httpx +import pytest + +from a2a.client import ClientConfig, ClientFactory +from a2a.client.transports import JsonRpcTransport, RestTransport +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentInterface, + TransportProtocol, +) + + +@pytest.fixture +def base_agent_card() -> AgentCard: + """Provides a base AgentCard for tests.""" + return AgentCard( + name='Test Agent', + description='An agent for testing.', + url='http://primary-url.com', + version='1.0.0', + capabilities=AgentCapabilities(), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport=TransportProtocol.jsonrpc, + ) + + +def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard): + """Verify that the factory selects the preferred transport by default.""" + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[ + TransportProtocol.jsonrpc, + TransportProtocol.http_json, + ], + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, JsonRpcTransport) + assert client._transport.url == 'http://primary-url.com' + + +def test_client_factory_selects_secondary_transport_url( + base_agent_card: AgentCard, +): + """Verify that the factory selects the correct URL for a secondary transport.""" + base_agent_card.additional_interfaces = [ + AgentInterface( + transport=TransportProtocol.http_json, + url='http://secondary-url.com', + ) + ] + # Client prefers REST, which is available as a secondary transport + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[ + TransportProtocol.http_json, + TransportProtocol.jsonrpc, + ], + use_client_preference=True, + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, RestTransport) + assert client._transport.url == 'http://secondary-url.com' + + +def test_client_factory_server_preference(base_agent_card: AgentCard): + """Verify that the factory respects server transport preference.""" + base_agent_card.preferred_transport = TransportProtocol.http_json + base_agent_card.additional_interfaces = [ + AgentInterface( + transport=TransportProtocol.jsonrpc, url='http://secondary-url.com' + ) + ] + # Client supports both, but server prefers REST + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[ + TransportProtocol.jsonrpc, + TransportProtocol.http_json, + ], + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, RestTransport) + assert client._transport.url == 'http://primary-url.com' + + +def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): + """Verify that the factory raises an error if no compatible transport is found.""" + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[TransportProtocol.grpc], + ) + factory = ClientFactory(config) + with pytest.raises(ValueError, match='no compatible transports found'): + factory.create(base_agent_card) diff --git a/tests/client/test_client_task_manager.py b/tests/client/test_client_task_manager.py new file mode 100644 index 00000000..fd626d2c --- /dev/null +++ b/tests/client/test_client_task_manager.py @@ -0,0 +1,172 @@ +import pytest +from unittest.mock import AsyncMock, Mock, patch +from a2a.client.client_task_manager import ClientTaskManager +from a2a.client.errors import ( + A2AClientInvalidArgsError, + A2AClientInvalidStateError, +) +from a2a.types import ( + Task, + TaskStatus, + TaskState, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, + Message, + Role, + Part, + TextPart, + Artifact, +) + + +@pytest.fixture +def task_manager(): + return ClientTaskManager() + + +@pytest.fixture +def sample_task(): + return Task( + id='task123', + context_id='context456', + status=TaskStatus(state=TaskState.working), + history=[], + artifacts=[], + ) + + +@pytest.fixture +def sample_message(): + return Message( + message_id='msg1', + role=Role.user, + parts=[Part(root=TextPart(text='Hello'))], + ) + + +def test_get_task_no_task_id_returns_none(task_manager: ClientTaskManager): + assert task_manager.get_task() is None + + +def test_get_task_or_raise_no_task_raises_error( + task_manager: ClientTaskManager, +): + with pytest.raises(A2AClientInvalidStateError, match='no current Task'): + task_manager.get_task_or_raise() + + +@pytest.mark.asyncio +async def test_save_task_event_with_task( + task_manager: ClientTaskManager, sample_task: Task +): + await task_manager.save_task_event(sample_task) + assert task_manager.get_task() == sample_task + assert task_manager._task_id == sample_task.id + assert task_manager._context_id == sample_task.context_id + + +@pytest.mark.asyncio +async def test_save_task_event_with_task_already_set_raises_error( + task_manager: ClientTaskManager, sample_task: Task +): + await task_manager.save_task_event(sample_task) + with pytest.raises( + A2AClientInvalidArgsError, + match='Task is already set, create new manager for new tasks.', + ): + await task_manager.save_task_event(sample_task) + + +@pytest.mark.asyncio +async def test_save_task_event_with_status_update( + task_manager: ClientTaskManager, sample_task: Task, sample_message: Message +): + await task_manager.save_task_event(sample_task) + status_update = TaskStatusUpdateEvent( + task_id=sample_task.id, + context_id=sample_task.context_id, + status=TaskStatus(state=TaskState.completed, message=sample_message), + final=True, + ) + updated_task = await task_manager.save_task_event(status_update) + assert updated_task.status.state == TaskState.completed + assert updated_task.history == [sample_message] + + +@pytest.mark.asyncio +async def test_save_task_event_with_artifact_update( + task_manager: ClientTaskManager, sample_task: Task +): + await task_manager.save_task_event(sample_task) + artifact = Artifact( + artifact_id='art1', parts=[Part(root=TextPart(text='artifact content'))] + ) + artifact_update = TaskArtifactUpdateEvent( + task_id=sample_task.id, + context_id=sample_task.context_id, + artifact=artifact, + ) + + with patch( + 'a2a.client.client_task_manager.append_artifact_to_task' + ) as mock_append: + updated_task = await task_manager.save_task_event(artifact_update) + mock_append.assert_called_once_with(updated_task, artifact_update) + + +@pytest.mark.asyncio +async def test_save_task_event_creates_task_if_not_exists( + task_manager: ClientTaskManager, +): + status_update = TaskStatusUpdateEvent( + task_id='new_task', + context_id='new_context', + status=TaskStatus(state=TaskState.working), + final=False, + ) + updated_task = await task_manager.save_task_event(status_update) + assert updated_task is not None + assert updated_task.id == 'new_task' + assert updated_task.status.state == TaskState.working + + +@pytest.mark.asyncio +async def test_process_with_task_event( + task_manager: ClientTaskManager, sample_task: Task +): + with patch.object( + task_manager, 'save_task_event', new_callable=AsyncMock + ) as mock_save: + await task_manager.process(sample_task) + mock_save.assert_called_once_with(sample_task) + + +@pytest.mark.asyncio +async def test_process_with_non_task_event(task_manager: ClientTaskManager): + with patch.object( + task_manager, 'save_task_event', new_callable=Mock + ) as mock_save: + non_task_event = 'not a task event' + await task_manager.process(non_task_event) + mock_save.assert_not_called() + + +def test_update_with_message( + task_manager: ClientTaskManager, sample_task: Task, sample_message: Message +): + updated_task = task_manager.update_with_message(sample_message, sample_task) + assert updated_task.history == [sample_message] + + +def test_update_with_message_moves_status_message( + task_manager: ClientTaskManager, sample_task: Task, sample_message: Message +): + status_message = Message( + message_id='status_msg', + role=Role.agent, + parts=[Part(root=TextPart(text='Status'))], + ) + sample_task.status.message = status_message + updated_task = task_manager.update_with_message(sample_message, sample_task) + assert updated_task.history == [status_message, sample_message] + assert updated_task.status.message is None diff --git a/tests/client/test_grpc_client.py b/tests/client/test_grpc_client.py index 26967d73..c2dbc2b8 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/test_grpc_client.py @@ -2,7 +2,7 @@ import pytest -from a2a.client import A2AGrpcClient +from a2a.client.transports.grpc import GrpcTransport from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCapabilities, @@ -51,11 +51,14 @@ def sample_agent_card() -> AgentCard: @pytest.fixture -def grpc_client( +def grpc_transport( mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard -) -> A2AGrpcClient: - """Provides an A2AGrpcClient instance.""" - return A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=sample_agent_card) +) -> GrpcTransport: + """Provides a GrpcTransport instance.""" + channel = AsyncMock() + transport = GrpcTransport(channel=channel, agent_card=sample_agent_card) + transport.stub = mock_grpc_stub + return transport @pytest.fixture @@ -92,7 +95,7 @@ def sample_message() -> Message: @pytest.mark.asyncio async def test_send_message_task_response( - grpc_client: A2AGrpcClient, + grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_message_send_params: MessageSendParams, sample_task: Task, @@ -102,7 +105,7 @@ async def test_send_message_task_response( task=proto_utils.ToProto.task(sample_task) ) - response = await grpc_client.send_message(sample_message_send_params) + response = await grpc_transport.send_message(sample_message_send_params) mock_grpc_stub.SendMessage.assert_awaited_once() assert isinstance(response, Task) @@ -111,13 +114,13 @@ async def test_send_message_task_response( @pytest.mark.asyncio async def test_get_task( - grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task + grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task ): """Test retrieving a task.""" mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(sample_task) params = TaskQueryParams(id=sample_task.id) - response = await grpc_client.get_task(params) + response = await grpc_transport.get_task(params) mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest(name=f'tasks/{sample_task.id}') @@ -127,7 +130,7 @@ async def test_get_task( @pytest.mark.asyncio async def test_cancel_task( - grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task + grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task ): """Test cancelling a task.""" cancelled_task = sample_task.model_copy() @@ -137,7 +140,7 @@ async def test_cancel_task( ) params = TaskIdParams(id=sample_task.id) - response = await grpc_client.cancel_task(params) + response = await grpc_transport.cancel_task(params) mock_grpc_stub.CancelTask.assert_awaited_once_with( a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py new file mode 100644 index 00000000..58feec25 --- /dev/null +++ b/tests/client/test_jsonrpc_client.py @@ -0,0 +1,787 @@ +import json + +from collections.abc import AsyncGenerator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from httpx_sse import EventSource, SSEError, ServerSentEvent + +from a2a.client import ( + A2ACardResolver, + A2AClientHTTPError, + A2AClientJSONError, + A2AClientTimeoutError, + create_text_message_object, +) +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + InvalidParamsError, + Message, + MessageSendParams, + PushNotificationConfig, + Role, + SendMessageSuccessResponse, + Task, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, +) +from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH + + +AGENT_CARD = AgentCard( + name='Hello World Agent', + description='Just a hello world agent', + url='http://localhost:9999/', + version='1.0.0', + default_input_modes=['text'], + default_output_modes=['text'], + capabilities=AgentCapabilities(), + skills=[ + AgentSkill( + id='hello_world', + name='Returns hello world', + description='just returns hello world', + tags=['hello world'], + examples=['hi', 'hello world'], + ) + ], +) + +AGENT_CARD_EXTENDED = AGENT_CARD.model_copy( + update={ + 'name': 'Hello World Agent - Extended Edition', + 'skills': [ + *AGENT_CARD.skills, + AgentSkill( + id='extended_skill', + name='Super Greet', + description='A more enthusiastic greeting.', + tags=['extended'], + examples=['super hi'], + ), + ], + 'version': '1.0.1', + } +) + +AGENT_CARD_SUPPORTS_EXTENDED = AGENT_CARD.model_copy( + update={'supports_authenticated_extended_card': True} +) +AGENT_CARD_NO_URL_SUPPORTS_EXTENDED = AGENT_CARD_SUPPORTS_EXTENDED.model_copy( + update={'url': ''} +) + +MINIMAL_TASK: dict[str, Any] = { + 'id': 'task-abc', + 'contextId': 'session-xyz', + 'status': {'state': 'working'}, + 'kind': 'task', +} + +MINIMAL_CANCELLED_TASK: dict[str, Any] = { + 'id': 'task-abc', + 'contextId': 'session-xyz', + 'status': {'state': 'canceled'}, + 'kind': 'task', +} + + +@pytest.fixture +def mock_httpx_client() -> AsyncMock: + return AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def mock_agent_card() -> MagicMock: + mock = MagicMock(spec=AgentCard, url='http://agent.example.com/api') + mock.supports_authenticated_extended_card = False + return mock + + +async def async_iterable_from_list( + items: list[ServerSentEvent], +) -> AsyncGenerator[ServerSentEvent, None]: + """Helper to create an async iterable from a list.""" + for item in items: + yield item + + +class TestA2ACardResolver: + BASE_URL = 'http://example.com' + AGENT_CARD_PATH = AGENT_CARD_WELL_KNOWN_PATH + FULL_AGENT_CARD_URL = f'{BASE_URL}{AGENT_CARD_PATH}' + EXTENDED_AGENT_CARD_PATH = '/agent/authenticatedExtendedCard' + + @pytest.mark.asyncio + async def test_init_parameters_stored_correctly( + self, mock_httpx_client: AsyncMock + ): + base_url = 'http://example.com' + custom_path = '/custom/agent-card.json' + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=base_url, + agent_card_path=custom_path, + ) + assert resolver.base_url == base_url + assert resolver.agent_card_path == custom_path.lstrip('/') + assert resolver.httpx_client == mock_httpx_client + + resolver_default_path = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=base_url, + ) + assert ( + '/' + resolver_default_path.agent_card_path + == AGENT_CARD_WELL_KNOWN_PATH + ) + + @pytest.mark.asyncio + async def test_init_strips_slashes(self, mock_httpx_client: AsyncMock): + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url='http://example.com/', + agent_card_path='/.well-known/agent-card.json/', + ) + assert resolver.base_url == 'http://example.com' + assert resolver.agent_card_path == '.well-known/agent-card.json/' + + @pytest.mark.asyncio + async def test_get_agent_card_success_public_only( + self, mock_httpx_client: AsyncMock + ): + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = AGENT_CARD.model_dump(mode='json') + mock_httpx_client.get.return_value = mock_response + + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=self.BASE_URL, + agent_card_path=self.AGENT_CARD_PATH, + ) + agent_card = await resolver.get_agent_card(http_kwargs={'timeout': 10}) + + mock_httpx_client.get.assert_called_once_with( + self.FULL_AGENT_CARD_URL, timeout=10 + ) + mock_response.raise_for_status.assert_called_once() + assert isinstance(agent_card, AgentCard) + assert agent_card == AGENT_CARD + assert mock_httpx_client.get.call_count == 1 + + @pytest.mark.asyncio + async def test_get_agent_card_success_with_specified_path_for_extended_card( + self, mock_httpx_client: AsyncMock + ): + extended_card_response = AsyncMock(spec=httpx.Response) + extended_card_response.status_code = 200 + extended_card_response.json.return_value = ( + AGENT_CARD_EXTENDED.model_dump(mode='json') + ) + mock_httpx_client.get.return_value = extended_card_response + + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=self.BASE_URL, + agent_card_path=self.AGENT_CARD_PATH, + ) + + auth_kwargs = {'headers': {'Authorization': 'Bearer test token'}} + agent_card_result = await resolver.get_agent_card( + relative_card_path=self.EXTENDED_AGENT_CARD_PATH, + http_kwargs=auth_kwargs, + ) + + expected_extended_url = ( + f'{self.BASE_URL}/{self.EXTENDED_AGENT_CARD_PATH.lstrip("/")}' + ) + mock_httpx_client.get.assert_called_once_with( + expected_extended_url, **auth_kwargs + ) + extended_card_response.raise_for_status.assert_called_once() + assert isinstance(agent_card_result, AgentCard) + assert agent_card_result == AGENT_CARD_EXTENDED + + @pytest.mark.asyncio + async def test_get_agent_card_validation_error( + self, mock_httpx_client: AsyncMock + ): + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = { + 'invalid_field': 'value', + 'name': 'Test Agent', + } + mock_httpx_client.get.return_value = mock_response + + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, base_url=self.BASE_URL + ) + with pytest.raises(A2AClientJSONError) as exc_info: + await resolver.get_agent_card() + + assert ( + f'Failed to validate agent card structure from {self.FULL_AGENT_CARD_URL}' + in str(exc_info.value) + ) + assert 'invalid_field' in str(exc_info.value) + assert mock_httpx_client.get.call_count == 1 + + @pytest.mark.asyncio + async def test_get_agent_card_http_status_error( + self, mock_httpx_client: AsyncMock + ): + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 404 + mock_response.text = 'Not Found' + http_status_error = httpx.HTTPStatusError( + 'Not Found', request=MagicMock(), response=mock_response + ) + mock_httpx_client.get.side_effect = http_status_error + + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=self.BASE_URL, + agent_card_path=self.AGENT_CARD_PATH, + ) + + with pytest.raises(A2AClientHTTPError) as exc_info: + await resolver.get_agent_card() + + assert exc_info.value.status_code == 404 + assert ( + f'Failed to fetch agent card from {self.FULL_AGENT_CARD_URL}' + in str(exc_info.value) + ) + assert 'Not Found' in str(exc_info.value) + mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL) + + @pytest.mark.asyncio + async def test_get_agent_card_json_decode_error( + self, mock_httpx_client: AsyncMock + ): + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + json_error = json.JSONDecodeError('Expecting value', 'doc', 0) + mock_response.json.side_effect = json_error + mock_httpx_client.get.return_value = mock_response + + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=self.BASE_URL, + agent_card_path=self.AGENT_CARD_PATH, + ) + + with pytest.raises(A2AClientJSONError) as exc_info: + await resolver.get_agent_card() + + assert ( + f'Failed to parse JSON for agent card from {self.FULL_AGENT_CARD_URL}' + in str(exc_info.value) + ) + assert 'Expecting value' in str(exc_info.value) + mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL) + + @pytest.mark.asyncio + async def test_get_agent_card_request_error( + self, mock_httpx_client: AsyncMock + ): + request_error = httpx.RequestError('Network issue', request=MagicMock()) + mock_httpx_client.get.side_effect = request_error + + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=self.BASE_URL, + agent_card_path=self.AGENT_CARD_PATH, + ) + + with pytest.raises(A2AClientHTTPError) as exc_info: + await resolver.get_agent_card() + + assert exc_info.value.status_code == 503 + assert ( + f'Network communication error fetching agent card from {self.FULL_AGENT_CARD_URL}' + in str(exc_info.value) + ) + assert 'Network issue' in str(exc_info.value) + mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL) + + +class TestJsonRpcTransport: + AGENT_URL = 'http://agent.example.com/api' + + def test_init_with_agent_card( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + assert client.url == mock_agent_card.url + assert client.httpx_client == mock_httpx_client + + def test_init_with_url(self, mock_httpx_client: AsyncMock): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, url=self.AGENT_URL + ) + assert client.url == self.AGENT_URL + assert client.httpx_client == mock_httpx_client + + def test_init_with_agent_card_and_url_prioritizes_url( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://otherurl.com', + ) + assert client.url == 'http://otherurl.com' + + def test_init_raises_value_error_if_no_card_or_url( + self, mock_httpx_client: AsyncMock + ): + with pytest.raises(ValueError) as exc_info: + JsonRpcTransport(httpx_client=mock_httpx_client) + assert 'Must provide either agent_card or url' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_send_message_success( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + success_response = create_text_message_object( + role=Role.agent, content='Hi there!' + ) + rpc_response = SendMessageSuccessResponse( + id='123', jsonrpc='2.0', result=success_response + ) + response = httpx.Response( + 200, json=rpc_response.model_dump(mode='json') + ) + response.request = httpx.Request('POST', 'http://agent.example.com/api') + mock_httpx_client.post.return_value = response + + response = await client.send_message(request=params) + + assert isinstance(response, Message) + assert response.model_dump() == success_response.model_dump() + + @pytest.mark.asyncio + async def test_send_message_error_response( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + error_response = InvalidParamsError() + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'error': error_response.model_dump(exclude_none=True), + } + mock_httpx_client.post.return_value.json.return_value = rpc_response + + with pytest.raises(Exception): + await client.send_message(request=params) + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_success( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_stream_response_1 = SendMessageSuccessResponse( + id='stream_id_123', + jsonrpc='2.0', + result=create_text_message_object( + content='First part ', role=Role.agent + ), + ) + mock_stream_response_2 = SendMessageSuccessResponse( + id='stream_id_123', + jsonrpc='2.0', + result=create_text_message_object( + content='second part ', role=Role.agent + ), + ) + sse_event_1 = ServerSentEvent( + data=mock_stream_response_1.model_dump_json() + ) + sse_event_2 = ServerSentEvent( + data=mock_stream_response_2.model_dump_json() + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list( + [sse_event_1, sse_event_2] + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + results = [ + item async for item in client.send_message_streaming(request=params) + ] + + assert len(results) == 2 + assert isinstance(results[0], Message) + assert ( + results[0].model_dump() + == mock_stream_response_1.result.model_dump() + ) + assert isinstance(results[1], Message) + assert ( + results[1].model_dump() + == mock_stream_response_2.result.model_dump() + ) + + @pytest.mark.asyncio + async def test_send_request_http_status_error( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 404 + mock_response.text = 'Not Found' + http_error = httpx.HTTPStatusError( + 'Not Found', request=MagicMock(), response=mock_response + ) + mock_httpx_client.post.side_effect = http_error + + with pytest.raises(A2AClientHTTPError) as exc_info: + await client._send_request({}, {}) + + assert exc_info.value.status_code == 404 + assert 'Not Found' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_send_request_json_decode_error( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + json_error = json.JSONDecodeError('Expecting value', 'doc', 0) + mock_response.json.side_effect = json_error + mock_httpx_client.post.return_value = mock_response + + with pytest.raises(A2AClientJSONError) as exc_info: + await client._send_request({}, {}) + + assert 'Expecting value' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_send_request_httpx_request_error( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + request_error = httpx.RequestError('Network issue', request=MagicMock()) + mock_httpx_client.post.side_effect = request_error + + with pytest.raises(A2AClientHTTPError) as exc_info: + await client._send_request({}, {}) + + assert exc_info.value.status_code == 503 + assert 'Network communication error' in str(exc_info.value) + assert 'Network issue' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_send_message_client_timeout( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + mock_httpx_client.post.side_effect = httpx.ReadTimeout( + 'Request timed out' + ) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + + with pytest.raises(A2AClientTimeoutError) as exc_info: + await client.send_message(request=params) + + assert 'Client Request timed out' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_get_task_success( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = TaskQueryParams(id='task-abc') + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': MINIMAL_TASK, + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.get_task(request=params) + + assert isinstance(response, Task) + assert ( + response.model_dump() + == Task.model_validate(MINIMAL_TASK).model_dump() + ) + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/get' + + @pytest.mark.asyncio + async def test_cancel_task_success( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = TaskIdParams(id='task-abc') + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': MINIMAL_CANCELLED_TASK, + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.cancel_task(request=params) + + assert isinstance(response, Task) + assert ( + response.model_dump() + == Task.model_validate(MINIMAL_CANCELLED_TASK).model_dump() + ) + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/cancel' + + @pytest.mark.asyncio + async def test_set_task_callback_success( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = TaskPushNotificationConfig( + task_id='task-abc', + push_notification_config=PushNotificationConfig( + url='http://callback.com' + ), + ) + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': params.model_dump(mode='json'), + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.set_task_callback(request=params) + + assert isinstance(response, TaskPushNotificationConfig) + assert response.model_dump() == params.model_dump() + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/pushNotificationConfig/set' + + @pytest.mark.asyncio + async def test_get_task_callback_success( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = TaskIdParams(id='task-abc') + expected_response = TaskPushNotificationConfig( + task_id='task-abc', + push_notification_config=PushNotificationConfig( + url='http://callback.com' + ), + ) + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': expected_response.model_dump(mode='json'), + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.get_task_callback(request=params) + + assert isinstance(response, TaskPushNotificationConfig) + assert response.model_dump() == expected_response.model_dump() + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/pushNotificationConfig/get' + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_sse_error( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.side_effect = SSEError( + 'Simulated SSE error' + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientHTTPError): + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_json_error( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + sse_event = ServerSentEvent(data='{invalid json') + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list( + [sse_event] + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientJSONError): + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_request_error( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.side_effect = httpx.RequestError( + 'Simulated request error', request=MagicMock() + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientHTTPError): + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + @pytest.mark.asyncio + async def test_get_card_no_card_provided( + self, mock_httpx_client: AsyncMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, url=self.AGENT_URL + ) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = AGENT_CARD.model_dump(mode='json') + mock_httpx_client.get.return_value = mock_response + + card = await client.get_card() + + assert card == AGENT_CARD + mock_httpx_client.get.assert_called_once() + + @pytest.mark.asyncio + async def test_get_card_with_extended_card_support( + self, mock_httpx_client: AsyncMock + ): + agent_card = AGENT_CARD.model_copy( + update={'supports_authenticated_extended_card': True} + ) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=agent_card + ) + + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': AGENT_CARD_EXTENDED.model_dump(mode='json'), + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + card = await client.get_card() + + assert card == agent_card + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'agent/getAuthenticatedExtendedCard' + + @pytest.mark.asyncio + async def test_close(self, mock_httpx_client: AsyncMock): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, url=self.AGENT_URL + ) + await client.close() + mock_httpx_client.aclose.assert_called_once() diff --git a/tests/client/test_legacy_client.py b/tests/client/test_legacy_client.py new file mode 100644 index 00000000..1bd9e4ae --- /dev/null +++ b/tests/client/test_legacy_client.py @@ -0,0 +1,115 @@ +"""Tests for the legacy client compatibility layer.""" + +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest + +from a2a.client import A2AClient, A2AGrpcClient +from a2a.types import ( + AgentCapabilities, + AgentCard, + Message, + MessageSendParams, + Part, + Role, + SendMessageRequest, + Task, + TaskQueryParams, + TaskState, + TaskStatus, + TextPart, +) + + +@pytest.fixture +def mock_httpx_client() -> AsyncMock: + return AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def mock_grpc_stub() -> AsyncMock: + stub = AsyncMock() + stub._channel = MagicMock() + return stub + + +@pytest.fixture +def jsonrpc_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='A test agent', + url='http://test.agent.com/rpc', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport='jsonrpc', + ) + + +@pytest.fixture +def grpc_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='A test agent', + url='http://test.agent.com/rpc', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport='grpc', + ) + + +@pytest.mark.asyncio +async def test_a2a_client_send_message( + mock_httpx_client: AsyncMock, jsonrpc_agent_card: AgentCard +): + client = A2AClient( + httpx_client=mock_httpx_client, agent_card=jsonrpc_agent_card + ) + + # Mock the underlying transport's send_message method + mock_response_task = Task( + id='task-123', + context_id='ctx-456', + status=TaskStatus(state=TaskState.completed), + ) + + client._transport.send_message = AsyncMock(return_value=mock_response_task) + + message = Message( + message_id='msg-123', + role=Role.user, + parts=[Part(root=TextPart(text='Hello'))], + ) + request = SendMessageRequest( + id='req-123', params=MessageSendParams(message=message) + ) + response = await client.send_message(request) + + assert response.root.result.id == 'task-123' + + +@pytest.mark.asyncio +async def test_a2a_grpc_client_get_task( + mock_grpc_stub: AsyncMock, grpc_agent_card: AgentCard +): + client = A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=grpc_agent_card) + + mock_response_task = Task( + id='task-456', + context_id='ctx-789', + status=TaskStatus(state=TaskState.working), + ) + + client.get_task = AsyncMock(return_value=mock_response_task) + + params = TaskQueryParams(id='task-456') + response = await client.get_task(params) + + assert response.id == 'task-456' + client.get_task.assert_awaited_once_with(params) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py new file mode 100644 index 00000000..46907ee6 --- /dev/null +++ b/tests/integration/test_client_server_integration.py @@ -0,0 +1,747 @@ +import asyncio + +from collections.abc import AsyncGenerator +from typing import NamedTuple +from unittest.mock import ANY, AsyncMock + +import grpc +import httpx +import pytest +import pytest_asyncio + +from grpc.aio import Channel + +from a2a.client.transports import JsonRpcTransport, RestTransport +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.grpc import GrpcTransport +from a2a.grpc import a2a_pb2_grpc +from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from a2a.server.request_handlers import GrpcHandler, RequestHandler +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentInterface, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Part, + PushNotificationConfig, + Role, + Task, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, + TransportProtocol, +) + + +# --- Test Constants --- + +TASK_FROM_STREAM = Task( + id='task-123-stream', + context_id='ctx-456-stream', + status=TaskStatus(state=TaskState.completed), + kind='task', +) + +TASK_FROM_BLOCKING = Task( + id='task-789-blocking', + context_id='ctx-101-blocking', + status=TaskStatus(state=TaskState.completed), + kind='task', +) + +GET_TASK_RESPONSE = Task( + id='task-get-456', + context_id='ctx-get-789', + status=TaskStatus(state=TaskState.working), + kind='task', +) + +CANCEL_TASK_RESPONSE = Task( + id='task-cancel-789', + context_id='ctx-cancel-101', + status=TaskStatus(state=TaskState.canceled), + kind='task', +) + +CALLBACK_CONFIG = TaskPushNotificationConfig( + task_id='task-callback-123', + push_notification_config=PushNotificationConfig( + id='pnc-abc', url='http://callback.example.com', token='' + ), +) + +RESUBSCRIBE_EVENT = TaskStatusUpdateEvent( + task_id='task-resub-456', + context_id='ctx-resub-789', + status=TaskStatus(state=TaskState.working), + final=False, +) + + +# --- Test Fixtures --- + + +@pytest.fixture +def mock_request_handler() -> AsyncMock: + """Provides a mock RequestHandler for the server-side handlers.""" + handler = AsyncMock(spec=RequestHandler) + + # Configure on_message_send for non-streaming calls + handler.on_message_send.return_value = TASK_FROM_BLOCKING + + # Configure on_message_send_stream for streaming calls + async def stream_side_effect(*args, **kwargs): + yield TASK_FROM_STREAM + + handler.on_message_send_stream.side_effect = stream_side_effect + + # Configure other methods + handler.on_get_task.return_value = GET_TASK_RESPONSE + handler.on_cancel_task.return_value = CANCEL_TASK_RESPONSE + handler.on_set_task_push_notification_config.side_effect = ( + lambda params, context: params + ) + handler.on_get_task_push_notification_config.return_value = CALLBACK_CONFIG + + async def resubscribe_side_effect(*args, **kwargs): + yield RESUBSCRIBE_EVENT + + handler.on_resubscribe_to_task.side_effect = resubscribe_side_effect + + return handler + + +@pytest.fixture +def agent_card() -> AgentCard: + """Provides a sample AgentCard for tests.""" + return AgentCard( + name='Test Agent', + description='An agent for integration testing.', + url='http://testserver', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True, push_notifications=True), + skills=[], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + preferred_transport=TransportProtocol.jsonrpc, + supports_authenticated_extended_card=True, + additional_interfaces=[ + AgentInterface( + transport=TransportProtocol.http_json, url='http://testserver' + ), + AgentInterface( + transport=TransportProtocol.grpc, url='localhost:50051' + ), + ], + ) + + +class TransportSetup(NamedTuple): + """Holds the transport and handler for a given test.""" + + transport: ClientTransport + handler: AsyncMock + + +# --- HTTP/JSON-RPC/REST Setup --- + + +@pytest.fixture +def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): + """A base fixture to patch the sse-starlette event loop issue.""" + from sse_starlette import sse + + sse.AppStatus.should_exit_event = asyncio.Event() + yield mock_request_handler, agent_card + + +@pytest.fixture +def jsonrpc_setup(http_base_setup) -> TransportSetup: + """Sets up the JsonRpcTransport and in-memory server.""" + mock_request_handler, agent_card = http_base_setup + app_builder = A2AFastAPIApplication( + agent_card, mock_request_handler, extended_agent_card=agent_card + ) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + transport = JsonRpcTransport( + httpx_client=httpx_client, agent_card=agent_card + ) + return TransportSetup(transport=transport, handler=mock_request_handler) + + +@pytest.fixture +def rest_setup(http_base_setup) -> TransportSetup: + """Sets up the RestTransport and in-memory server.""" + mock_request_handler, agent_card = http_base_setup + app_builder = A2ARESTFastAPIApplication(agent_card, mock_request_handler) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card) + return TransportSetup(transport=transport, handler=mock_request_handler) + + +# --- gRPC Setup --- + + +@pytest_asyncio.fixture +async def grpc_server_and_handler( + mock_request_handler: AsyncMock, agent_card: AgentCard +) -> AsyncGenerator[tuple[str, AsyncMock], None]: + """Creates and manages an in-process gRPC test server.""" + server = grpc.aio.server() + port = server.add_insecure_port('[::]:0') + server_address = f'localhost:{port}' + servicer = GrpcHandler(agent_card, mock_request_handler) + a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) + await server.start() + yield server_address, mock_request_handler + await server.stop(0) + + +# --- The Integration Tests --- + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_sends_message_streaming( + transport_setup_fixture: str, request +) -> None: + """ + Integration test for HTTP-based transports (JSON-RPC, REST) streaming. + """ + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + message_to_send = Message( + role=Role.user, + message_id='msg-integration-test', + parts=[Part(root=TextPart(text='Hello, integration test!'))], + ) + params = MessageSendParams(message=message_to_send) + + stream = transport.send_message_streaming(request=params) + first_event = await anext(stream) + + assert first_event.id == TASK_FROM_STREAM.id + assert first_event.context_id == TASK_FROM_STREAM.context_id + + handler.on_message_send_stream.assert_called_once() + call_args, _ = handler.on_message_send_stream.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_sends_message_streaming( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + """ + Integration test specifically for the gRPC transport streaming. + """ + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + + message_to_send = Message( + role=Role.user, + message_id='msg-grpc-integration-test', + parts=[Part(root=TextPart(text='Hello, gRPC integration test!'))], + ) + params = MessageSendParams(message=message_to_send) + + stream = transport.send_message_streaming(request=params) + first_event = await anext(stream) + + assert first_event.id == TASK_FROM_STREAM.id + assert first_event.context_id == TASK_FROM_STREAM.context_id + + handler.on_message_send_stream.assert_called_once() + call_args, _ = handler.on_message_send_stream.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_sends_message_blocking( + transport_setup_fixture: str, request +) -> None: + """ + Integration test for HTTP-based transports (JSON-RPC, REST) blocking. + """ + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + message_to_send = Message( + role=Role.user, + message_id='msg-integration-test-blocking', + parts=[Part(root=TextPart(text='Hello, blocking test!'))], + ) + params = MessageSendParams(message=message_to_send) + + result = await transport.send_message(request=params) + + assert result.id == TASK_FROM_BLOCKING.id + assert result.context_id == TASK_FROM_BLOCKING.context_id + + handler.on_message_send.assert_awaited_once() + call_args, _ = handler.on_message_send.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_sends_message_blocking( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + """ + Integration test specifically for the gRPC transport blocking. + """ + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + + message_to_send = Message( + role=Role.user, + message_id='msg-grpc-integration-test-blocking', + parts=[Part(root=TextPart(text='Hello, gRPC blocking test!'))], + ) + params = MessageSendParams(message=message_to_send) + + result = await transport.send_message(request=params) + + assert result.id == TASK_FROM_BLOCKING.id + assert result.context_id == TASK_FROM_BLOCKING.context_id + + handler.on_message_send.assert_awaited_once() + call_args, _ = handler.on_message_send.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_get_task( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = TaskQueryParams(id=GET_TASK_RESPONSE.id) + result = await transport.get_task(request=params) + + assert result.id == GET_TASK_RESPONSE.id + handler.on_get_task.assert_awaited_once_with(params, ANY) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_task( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + + params = TaskQueryParams(id=GET_TASK_RESPONSE.id) + result = await transport.get_task(request=params) + + assert result.id == GET_TASK_RESPONSE.id + handler.on_get_task.assert_awaited_once() + assert handler.on_get_task.call_args[0][0].id == GET_TASK_RESPONSE.id + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_cancel_task( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = TaskIdParams(id=CANCEL_TASK_RESPONSE.id) + result = await transport.cancel_task(request=params) + + assert result.id == CANCEL_TASK_RESPONSE.id + handler.on_cancel_task.assert_awaited_once_with(params, ANY) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_cancel_task( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + + params = TaskIdParams(id=CANCEL_TASK_RESPONSE.id) + result = await transport.cancel_task(request=params) + + assert result.id == CANCEL_TASK_RESPONSE.id + handler.on_cancel_task.assert_awaited_once() + assert handler.on_cancel_task.call_args[0][0].id == CANCEL_TASK_RESPONSE.id + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_set_task_callback( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = CALLBACK_CONFIG + result = await transport.set_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) + handler.on_set_task_push_notification_config.assert_awaited_once_with( + params, ANY + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_set_task_callback( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + + params = CALLBACK_CONFIG + result = await transport.set_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) + handler.on_set_task_push_notification_config.assert_awaited_once() + assert ( + handler.on_set_task_push_notification_config.call_args[0][0].task_id + == CALLBACK_CONFIG.task_id + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_get_task_callback( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = GetTaskPushNotificationConfigParams( + id=CALLBACK_CONFIG.task_id, + push_notification_config_id=CALLBACK_CONFIG.push_notification_config.id, + ) + result = await transport.get_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) + handler.on_get_task_push_notification_config.assert_awaited_once_with( + params, ANY + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_task_callback( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + + params = GetTaskPushNotificationConfigParams( + id=CALLBACK_CONFIG.task_id, + push_notification_config_id=CALLBACK_CONFIG.push_notification_config.id, + ) + result = await transport.get_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) + handler.on_get_task_push_notification_config.assert_awaited_once() + assert ( + handler.on_get_task_push_notification_config.call_args[0][0].id + == CALLBACK_CONFIG.task_id + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_resubscribe( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = TaskIdParams(id=RESUBSCRIBE_EVENT.task_id) + stream = transport.resubscribe(request=params) + first_event = await anext(stream) + + assert first_event.task_id == RESUBSCRIBE_EVENT.task_id + handler.on_resubscribe_to_task.assert_called_once_with(params, ANY) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_resubscribe( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + + params = TaskIdParams(id=RESUBSCRIBE_EVENT.task_id) + stream = transport.resubscribe(request=params) + first_event = await anext(stream) + + assert first_event.task_id == RESUBSCRIBE_EVENT.task_id + handler.on_resubscribe_to_task.assert_called_once() + assert ( + handler.on_resubscribe_to_task.call_args[0][0].id + == RESUBSCRIBE_EVENT.task_id + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_get_card( + transport_setup_fixture: str, request, agent_card: AgentCard +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + + # The transport starts with a minimal card, get_card() fetches the full one + transport.agent_card.supports_authenticated_extended_card = True + result = await transport.get_card() + + assert result.name == agent_card.name + assert transport.agent_card.name == agent_card.name + assert transport._needs_extended_card is False + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_card( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, _ = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + + # The transport starts with a minimal card, get_card() fetches the full one + transport.agent_card.supports_authenticated_extended_card = True + result = await transport.get_card() + + assert result.name == agent_card.name + assert transport.agent_card.name == agent_card.name + assert transport._needs_extended_card is False + + await transport.close() diff --git a/tests/server/apps/rest/test_fastapi_app.py b/tests/server/apps/rest/test_fastapi_app.py new file mode 100644 index 00000000..49d2121c --- /dev/null +++ b/tests/server/apps/rest/test_fastapi_app.py @@ -0,0 +1,145 @@ +import logging + +from unittest.mock import MagicMock + +import pytest + +from fastapi import FastAPI +from google.protobuf import json_format +from httpx import ASGITransport, AsyncClient + +from a2a.grpc import a2a_pb2 +from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + AgentCard, + Message, + Part, + Role, + Task, + TaskState, + TaskStatus, + TextPart, +) + + +logger = logging.getLogger(__name__) + + +@pytest.fixture +async def agent_card() -> AgentCard: + mock_agent_card = MagicMock(spec=AgentCard) + mock_agent_card.url = 'http://mockurl.com' + mock_agent_card.supports_authenticated_extended_card = False + return mock_agent_card + + +@pytest.fixture +async def request_handler() -> RequestHandler: + return MagicMock(spec=RequestHandler) + + +@pytest.fixture +async def app( + agent_card: AgentCard, request_handler: RequestHandler +) -> FastAPI: + """Builds the FastAPI application for testing.""" + + return A2ARESTFastAPIApplication(agent_card, request_handler).build( + agent_card_url='/well-known/agent.json', rpc_url='' + ) + + +@pytest.fixture +async def client(app: FastAPI) -> AsyncClient: + return AsyncClient( + transport=ASGITransport(app=app), base_url='http://testapp' + ) + + +@pytest.mark.anyio +async def test_send_message_success_message( + client: AsyncClient, request_handler: MagicMock +) -> None: + expected_response = a2a_pb2.SendMessageResponse( + msg=a2a_pb2.Message( + message_id='test', + role=a2a_pb2.Role.ROLE_AGENT, + content=[ + a2a_pb2.Part(text='response message'), + ], + ), + ) + request_handler.on_message_send.return_value = Message( + message_id='test', + role=Role.agent, + parts=[Part(TextPart(text='response message'))], + ) + + request = a2a_pb2.SendMessageRequest( + request=a2a_pb2.Message(), + configuration=a2a_pb2.SendMessageConfiguration(), + ) + # To see log output, run pytest with '--log-cli=true --log-cli-level=INFO' + response = await client.post( + '/v1/message:send', json=json_format.MessageToDict(request) + ) + # request should always be successful + response.raise_for_status() + + actual_response = a2a_pb2.SendMessageResponse() + json_format.Parse(response.text, actual_response) + assert expected_response == actual_response + + +@pytest.mark.anyio +async def test_send_message_success_task( + client: AsyncClient, request_handler: MagicMock +) -> None: + expected_response = a2a_pb2.SendMessageResponse( + task=a2a_pb2.Task( + id='test_task_id', + context_id='test_context_id', + status=a2a_pb2.TaskStatus( + state=a2a_pb2.TaskState.TASK_STATE_COMPLETED, + update=a2a_pb2.Message( + message_id='test', + role=a2a_pb2.ROLE_AGENT, + content=[ + a2a_pb2.Part(text='response task message'), + ], + ), + ), + ), + ) + request_handler.on_message_send.return_value = Task( + id='test_task_id', + context_id='test_context_id', + status=TaskStatus( + state=TaskState.completed, + message=Message( + message_id='test', + role=Role.agent, + parts=[Part(TextPart(text='response task message'))], + ), + ), + ) + + request = a2a_pb2.SendMessageRequest( + request=a2a_pb2.Message(), + configuration=a2a_pb2.SendMessageConfiguration(), + ) + # To see log output, run pytest with '--log-cli=true --log-cli-level=INFO' + response = await client.post( + '/v1/message:send', json=json_format.MessageToDict(request) + ) + # request should always be successful + response.raise_for_status() + + actual_response = a2a_pb2.SendMessageResponse() + json_format.Parse(response.text, actual_response) + assert expected_response == actual_response + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index eb0a3459..8bd65e02 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -303,7 +303,7 @@ def side_effect(request, context: ServerCallContext): context.activated_extensions.add('baz') return types.Task( id='task-1', - contextId='ctx-1', + context_id='ctx-1', status=types.TaskStatus(state=types.TaskState.completed), ) @@ -338,9 +338,9 @@ async def test_send_message_with_comma_separated_extensions( (HTTP_EXTENSION_HEADER, 'baz , bar'), ) mock_request_handler.on_message_send.return_value = types.Message( - messageId='1', + message_id='1', role=types.Role.agent, - parts=[types.TextPart(text='test')], + parts=[types.Part(root=types.TextPart(text='test'))], ) await grpc_handler.SendMessage( @@ -368,7 +368,7 @@ async def side_effect(request, context: ServerCallContext): context.activated_extensions.add('baz') yield types.Task( id='task-1', - contextId='ctx-1', + context_id='ctx-1', status=types.TaskStatus(state=types.TaskState.working), ) diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index e5dcac6c..83848c24 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -237,9 +237,7 @@ def test_task_id_params_from_proto_invalid_name(self): assert isinstance(exc_info.value.error, types.InvalidParamsError) def test_task_push_config_from_proto_invalid_parent(self): - request = a2a_pb2.CreateTaskPushNotificationConfigRequest( - parent='invalid-parent' - ) + request = a2a_pb2.TaskPushNotificationConfig(name='invalid-name-format') with pytest.raises(ServerError) as exc_info: proto_utils.FromProto.task_push_notification_config(request) assert isinstance(exc_info.value.error, types.InvalidParamsError) diff --git a/uv.lock b/uv.lock index 4f4427cd..cbc718ad 100644 --- a/uv.lock +++ b/uv.lock @@ -11,8 +11,10 @@ name = "a2a-sdk" source = { editable = "." } dependencies = [ { name = "fastapi" }, + { name = "google-api-core" }, { name = "httpx" }, { name = "httpx-sse" }, + { name = "protobuf" }, { name = "pydantic" }, { name = "sse-starlette" }, { name = "starlette" }, @@ -23,11 +25,9 @@ encryption = [ { name = "cryptography" }, ] grpc = [ - { name = "google-api-core" }, { name = "grpcio" }, { name = "grpcio-reflection" }, { name = "grpcio-tools" }, - { name = "protobuf" }, ] mysql = [ { name = "sqlalchemy", extra = ["aiomysql", "asyncio"] }, @@ -60,6 +60,7 @@ dev = [ { name = "pyupgrade" }, { name = "respx" }, { name = "ruff" }, + { name = "trio" }, { name = "types-protobuf" }, { name = "types-requests" }, { name = "uv-dynamic-versioning" }, @@ -68,8 +69,8 @@ dev = [ [package.metadata] requires-dist = [ { name = "cryptography", marker = "extra == 'encryption'", specifier = ">=43.0.0" }, - { name = "fastapi", specifier = ">=0.115.2" }, - { name = "google-api-core", marker = "extra == 'grpc'", specifier = ">=1.26.0" }, + { name = "fastapi", specifier = ">=0.116.1" }, + { name = "google-api-core", specifier = ">=1.26.0" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.60" }, { name = "grpcio-reflection", marker = "extra == 'grpc'", specifier = ">=1.7.0" }, { name = "grpcio-tools", marker = "extra == 'grpc'", specifier = ">=1.60" }, @@ -77,7 +78,7 @@ requires-dist = [ { name = "httpx-sse", specifier = ">=0.4.0" }, { name = "opentelemetry-api", marker = "extra == 'telemetry'", specifier = ">=1.33.0" }, { name = "opentelemetry-sdk", marker = "extra == 'telemetry'", specifier = ">=1.33.0" }, - { name = "protobuf", marker = "extra == 'grpc'", specifier = "==5.29.5" }, + { name = "protobuf", specifier = "==5.29.5" }, { name = "pydantic", specifier = ">=2.11.3" }, { name = "sqlalchemy", extras = ["aiomysql", "aiosqlite", "asyncio", "postgresql-asyncpg"], marker = "extra == 'sql'", specifier = ">=2.0.0" }, { name = "sqlalchemy", extras = ["aiomysql", "asyncio"], marker = "extra == 'mysql'", specifier = ">=2.0.0" }, @@ -102,6 +103,7 @@ dev = [ { name = "pyupgrade" }, { name = "respx", specifier = ">=0.20.2" }, { name = "ruff", specifier = ">=0.11.6" }, + { name = "trio" }, { name = "types-protobuf" }, { name = "types-requests" }, { name = "uv-dynamic-versioning", specifier = ">=0.8.2" }, @@ -216,6 +218,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/a4/cec76b3389c4c5ff66301cd100fe88c318563ec8a520e0b2e792b5b84972/asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e", size = 621623, upload-time = "2024-10-20T00:30:09.024Z" }, ] +[[package]] +name = "attrs" +version = "25.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/b0/1367933a8532ee6ff8d63537de4f1177af4bff9f3e829baf7331f595bb24/attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b", size = 812032, upload-time = "2025-03-13T11:10:22.779Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, +] + [[package]] name = "autoflake" version = "2.3.1" @@ -601,16 +612,16 @@ wheels = [ [[package]] name = "fastapi" -version = "0.115.13" +version = "0.116.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/20/64/ec0788201b5554e2a87c49af26b77a4d132f807a0fa9675257ac92c6aa0e/fastapi-0.115.13.tar.gz", hash = "sha256:55d1d25c2e1e0a0a50aceb1c8705cd932def273c102bff0b1c1da88b3c6eb307", size = 295680, upload-time = "2025-06-17T11:49:45.575Z" } +sdist = { url = "https://files.pythonhosted.org/packages/78/d7/6c8b3bfe33eeffa208183ec037fee0cce9f7f024089ab1c5d12ef04bd27c/fastapi-0.116.1.tar.gz", hash = "sha256:ed52cbf946abfd70c5a0dccb24673f0670deeb517a88b3544d03c2a6bf283143", size = 296485, upload-time = "2025-07-11T16:22:32.057Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/59/4a/e17764385382062b0edbb35a26b7cf76d71e27e456546277a42ba6545c6e/fastapi-0.115.13-py3-none-any.whl", hash = "sha256:0a0cab59afa7bab22f5eb347f8c9864b681558c278395e94035a741fc10cd865", size = 95315, upload-time = "2025-06-17T11:49:44.106Z" }, + { url = "https://files.pythonhosted.org/packages/e5/47/d63c60f59a59467fda0f93f46335c9d18526d7071f025cb5b89d5353ea42/fastapi-0.116.1-py3-none-any.whl", hash = "sha256:c46ac7c312df840f0c9e220f7964bada936781bc4e2e6eb71f1c4d7553786565", size = 95631, upload-time = "2025-07-11T16:22:30.485Z" }, ] [[package]] @@ -1211,6 +1222,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1a/89/267b0af1b1d0ba828f0e60642b6a5116ac1fd917cde7fc02821627029bd1/opentelemetry_semantic_conventions-0.55b1-py3-none-any.whl", hash = "sha256:5da81dfdf7d52e3d37f8fe88d5e771e191de924cfff5f550ab0b8f7b2409baed", size = 196223, upload-time = "2025-06-10T08:55:17.638Z" }, ] +[[package]] +name = "outcome" +version = "1.3.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/df/77698abfac98571e65ffeb0c1fba8ffd692ab8458d617a0eed7d9a8d38f2/outcome-1.3.0.post0.tar.gz", hash = "sha256:9dcf02e65f2971b80047b377468e72a268e15c0af3cf1238e6ff14f7f91143b8", size = 21060, upload-time = "2023-10-26T04:26:04.361Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/8b/5ab7257531a5d830fc8000c476e63c935488d74609b50f9384a643ec0a62/outcome-1.3.0.post0-py2.py3-none-any.whl", hash = "sha256:e771c5ce06d1415e356078d3bdd68523f284b4ce5419828922b6871e65eda82b", size = 10692, upload-time = "2023-10-26T04:26:02.532Z" }, +] + [[package]] name = "packaging" version = "25.0" @@ -1666,6 +1689,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, +] + [[package]] name = "sqlalchemy" version = "2.0.41" @@ -1810,6 +1842,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/75/8539d011f6be8e29f339c42e633aae3cb73bffa95dd0f9adec09b9c58e85/tomlkit-0.13.3-py3-none-any.whl", hash = "sha256:c89c649d79ee40629a9fda55f8ace8c6a1b42deb912b2a8fd8d942ddadb606b0", size = 38901, upload-time = "2025-06-05T07:13:43.546Z" }, ] +[[package]] +name = "trio" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "idna" }, + { name = "outcome" }, + { name = "sniffio" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/01/c1/68d582b4d3a1c1f8118e18042464bb12a7c1b75d64d75111b297687041e3/trio-0.30.0.tar.gz", hash = "sha256:0781c857c0c81f8f51e0089929a26b5bb63d57f927728a5586f7e36171f064df", size = 593776, upload-time = "2025-04-21T00:48:19.507Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/8e/3f6dfda475ecd940e786defe6df6c500734e686c9cd0a0f8ef6821e9b2f2/trio-0.30.0-py3-none-any.whl", hash = "sha256:3bf4f06b8decf8d3cf00af85f40a89824669e2d033bb32469d34840edcfc22a5", size = 499194, upload-time = "2025-04-21T00:48:17.167Z" }, +] + [[package]] name = "trove-classifiers" version = "2025.5.9.12"