From c7f4eb033f126334dac1e544de5bc75286968cde Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 29 Jul 2025 19:34:25 +0000 Subject: [PATCH 01/12] Gemini authored: refactor clients into BaseClient + ClientTransport --- src/a2a/client/__init__.py | 84 +-- src/a2a/client/base_client.py | 231 +++++++ src/a2a/client/client_factory.py | 159 +++-- src/a2a/client/grpc_client.py | 544 ---------------- src/a2a/client/jsonrpc_client.py | 851 -------------------------- src/a2a/client/rest_client.py | 833 ------------------------- src/a2a/client/transports/__init__.py | 19 + src/a2a/client/transports/base.py | 101 +++ src/a2a/client/transports/grpc.py | 190 ++++++ src/a2a/client/transports/jsonrpc.py | 337 ++++++++++ src/a2a/client/transports/rest.py | 383 ++++++++++++ 11 files changed, 1352 insertions(+), 2380 deletions(-) create mode 100644 src/a2a/client/base_client.py delete mode 100644 src/a2a/client/grpc_client.py delete mode 100644 src/a2a/client/jsonrpc_client.py delete mode 100644 src/a2a/client/rest_client.py create mode 100644 src/a2a/client/transports/__init__.py create mode 100644 src/a2a/client/transports/base.py create mode 100644 src/a2a/client/transports/grpc.py create mode 100644 src/a2a/client/transports/jsonrpc.py create mode 100644 src/a2a/client/transports/rest.py diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 6e88a03d..96d27033 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -11,7 +11,6 @@ from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer from a2a.client.client_factory import ( ClientFactory, - ClientProducer, minimal_agent_card, ) from a2a.client.errors import ( @@ -21,77 +20,28 @@ A2AClientTimeoutError, ) from a2a.client.helpers import create_text_message_object -from a2a.client.jsonrpc_client import ( - A2AClient, - JsonRpcClient, - JsonRpcTransportClient, - NewJsonRpcClient, -) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.client.rest_client import ( - NewRestfulClient, - RestClient, - RestTransportClient, -) logger = logging.getLogger(__name__) -try: - from a2a.client.grpc_client import ( - GrpcClient, - GrpcTransportClient, # type: ignore - NewGrpcClient, - ) -except ImportError as e: - _original_error = e - logger.debug( - 'A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s', - _original_error, - ) - - class GrpcTransportClient: # type: ignore - """Placeholder for A2AGrpcClient when dependencies are not installed.""" - - def __init__(self, *args, **kwargs): - raise ImportError( - 'To use A2AGrpcClient, its dependencies must be installed. ' - 'You can install them with \'pip install "a2a-sdk[grpc]"\'' - ) from _original_error -finally: - # For backward compatability define this alias. This will be deprecated in - # a future release. - A2AGrpcClient = GrpcTransportClient # type: ignore - __all__ = [ - 'A2ACardResolver', - 'A2AClient', # for backward compatability - 'A2AClientError', - 'A2AClientHTTPError', - 'A2AClientJSONError', - 'A2AClientTimeoutError', - 'A2AGrpcClient', # for backward compatability - 'AuthInterceptor', - 'Client', - 'ClientCallContext', - 'ClientCallInterceptor', - 'ClientConfig', - 'ClientEvent', - 'ClientFactory', - 'ClientProducer', - 'Consumer', - 'CredentialService', - 'GrpcClient', - 'GrpcTransportClient', - 'InMemoryContextCredentialStore', - 'JsonRpcClient', - 'JsonRpcTransportClient', - 'NewGrpcClient', - 'NewJsonRpcClient', - 'NewRestfulClient', - 'RestClient', - 'RestTransportClient', - 'create_text_message_object', - 'minimal_agent_card', + "A2ACardResolver", + "A2AClientError", + "A2AClientHTTPError", + "A2AClientJSONError", + "A2AClientTimeoutError", + "AuthInterceptor", + "Client", + "ClientCallContext", + "ClientCallInterceptor", + "ClientConfig", + "ClientEvent", + "ClientFactory", + "Consumer", + "CredentialService", + "InMemoryContextCredentialStore", + "create_text_message_object", + "minimal_agent_card", ] diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py new file mode 100644 index 00000000..01600366 --- /dev/null +++ b/src/a2a/client/base_client.py @@ -0,0 +1,231 @@ +from collections.abc import AsyncIterator + +from a2a.client.client import ( + Client, + ClientCallContext, + ClientConfig, + ClientEvent, + Consumer, +) +from a2a.client.client_task_manager import ClientTaskManager +from a2a.client.errors import A2AClientInvalidStateError +from a2a.client.middleware import ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendConfiguration, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) + + +class BaseClient(Client): + """Base implementation of the A2A client, containing transport-independent logic.""" + + def __init__( + self, + card: AgentCard, + config: ClientConfig, + transport: ClientTransport, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], + ): + super().__init__(consumers, middleware) + self._card = card + self._config = config + self._transport = transport + + async def send_message( + self, + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent | Message]: + """Sends a message to the agent. + + This method handles both streaming and non-streaming (polling) interactions + based on the client configuration and agent capabilities. It will yield + events as they are received from the agent. + + Args: + request: The message to send to the agent. + context: The client call context. + + Yields: + An async iterator of `ClientEvent` or a final `Message` response. + """ + config = MessageSendConfiguration( + accepted_output_modes=self._config.accepted_output_modes, + blocking=not self._config.polling, + push_notification_config=( + self._config.push_notification_configs[0] + if self._config.push_notification_configs + else None + ), + ) + params = MessageSendParams(message=request, configuration=config) + + if not self._config.streaming or not self._card.capabilities.streaming: + response = await self._transport.send_message(params, context=context) + result = ( + (response, None) if isinstance(response, Task) else response + ) + await self.consume(result, self._card) + yield result + return + + tracker = ClientTaskManager() + stream = self._transport.send_message_streaming(params, context=context) + + first_event = await anext(stream) + if isinstance(first_event, Message): + await self.consume(first_event, self._card) + yield first_event + return + + yield await self._process_response(tracker, first_event) + + async for event in stream: + yield await self._process_response(tracker, event) + + async def _process_response( + self, + tracker: ClientTaskManager, + event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent, + ) -> ClientEvent: + if isinstance(event, Message): + raise A2AClientInvalidStateError( + "received a streamed Message from server after first response; this is not supported" + ) + await tracker.process(event) + task = tracker.get_task_or_raise() + update = None if isinstance(event, Task) else event + client_event = (task, update) + await self.consume(client_event, self._card) + return client_event + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task. + + Args: + request: The `TaskQueryParams` object specifying the task ID. + context: The client call context. + + Returns: + A `Task` object representing the current state of the task. + """ + return await self._transport.get_task(request, context=context) + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + context: The client call context. + + Returns: + A `Task` object containing the updated task status. + """ + return await self._transport.cancel_task(request, context=context) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `TaskPushNotificationConfig` object with the new configuration. + context: The client call context. + + Returns: + The created or updated `TaskPushNotificationConfig` object. + """ + return await self._transport.set_task_callback(request, context=context) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigParams` object specifying the task. + context: The client call context. + + Returns: + A `TaskPushNotificationConfig` object containing the configuration. + """ + return await self._transport.get_task_callback(request, context=context) + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent]: + """Resubscribes to a task's event stream. + + This is only available if both the client and server support streaming. + + Args: + request: Parameters to identify the task to resubscribe to. + context: The client call context. + + Yields: + An async iterator of `ClientEvent` objects. + + Raises: + NotImplementedError: If streaming is not supported by the client or server. + """ + if not self._config.streaming or not self._card.capabilities.streaming: + raise NotImplementedError( + "client and/or server do not support resubscription." + ) + + tracker = ClientTaskManager() + async for event in self._transport.resubscribe(request, context=context): + yield await self._process_response(tracker, event) + + async def get_card( + self, *, context: ClientCallContext | None = None + ) -> AgentCard: + """Retrieves the agent's card. + + This will fetch the authenticated card if necessary and update the + client's internal state with the new card. + + Args: + context: The client call context. + + Returns: + The `AgentCard` for the agent. + """ + card = await self._transport.get_card(context=context) + self._card = card + return card + + async def close(self) -> None: + """Closes the underlying transport.""" + await self._transport.close() diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index f47be58a..a97f62d2 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -4,16 +4,14 @@ from collections.abc import Callable -from a2a.client.client import Client, ClientConfig, Consumer - +import httpx -try: - from a2a.client.grpc_client import NewGrpcClient -except ImportError: - NewGrpcClient = None -from a2a.client.jsonrpc_client import NewJsonRpcClient +from a2a.client.base_client import BaseClient +from a2a.client.client import Client, ClientConfig, Consumer from a2a.client.middleware import ClientCallInterceptor -from a2a.client.rest_client import NewRestfulClient +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.client.transports.rest import RestTransport from a2a.types import ( AgentCapabilities, AgentCard, @@ -22,34 +20,25 @@ ) +try: + from a2a.client.transports.grpc import GrpcTransport + from a2a.grpc import a2a_pb2_grpc +except ImportError: + GrpcTransport = None + a2a_pb2_grpc = None + + logger = logging.getLogger(__name__) -ClientProducer = Callable[ - [ - AgentCard, - ClientConfig, - list[Consumer], - list[ClientCallInterceptor], - ], - Client, + +TransportProducer = Callable[ + [AgentCard, ClientConfig, list[ClientCallInterceptor]], + ClientTransport, ] class ClientFactory: - """ClientFactory is used to generate the appropriate client for the agent. - - The factory is configured with a `ClientConfig` and optionally a list of - `Consumer`s to use for all generated `Client`s. The expected use is: - - factory = ClientFactory(config, consumers) - # Optionally register custom client implementations - factory.register('my_customer_transport', NewCustomTransportClient) - # Then with an agent card make a client with additional consumers and - # interceptors - client = factory.create(card, additional_consumers, interceptors) - # Now the client can be used the same regardless of transport and - # aligns client config with server capabilities. - """ + """ClientFactory is used to generate the appropriate client for the agent.""" def __init__( self, @@ -60,23 +49,41 @@ def __init__( consumers = [] self._config = config self._consumers = consumers - self._registry: dict[str, ClientProducer] = {} - # By default register the 3 core transports if in the config. - # Can be overridden with custom clients via the register method. - if TransportProtocol.jsonrpc in self._config.supported_transports: - self._registry[TransportProtocol.jsonrpc] = NewJsonRpcClient - if TransportProtocol.http_json in self._config.supported_transports: - self._registry[TransportProtocol.http_json] = NewRestfulClient - if TransportProtocol.grpc in self._config.supported_transports: - if NewGrpcClient is None: - raise ImportError( - 'To use GrpcClient, its dependencies must be installed. ' - 'You can install them with \'pip install "a2a-sdk[grpc]"\'' - ) - self._registry[TransportProtocol.grpc] = NewGrpcClient - - def register(self, label: str, generator: ClientProducer) -> None: - """Register a new client producer for a given transport label.""" + self._registry: dict[str, TransportProducer] = {} + self._register_defaults() + + def _register_defaults(self) -> None: + self.register( + TransportProtocol.jsonrpc, + lambda card, config, interceptors: JsonRpcTransport( + config.httpx_client or httpx.AsyncClient(), + card, + card.url, + interceptors, + ), + ) + self.register( + TransportProtocol.http_json, + lambda card, config, interceptors: RestTransport( + config.httpx_client or httpx.AsyncClient(), + card, + card.url, + interceptors, + ), + ) + if GrpcTransport: + self.register( + TransportProtocol.grpc, + lambda card, config, interceptors: GrpcTransport( + a2a_pb2_grpc.A2AServiceStub( + config.grpc_channel_factory(card.url) + ), + card, + ), + ) + + def register(self, label: str, generator: TransportProducer) -> None: + """Register a new transport producer for a given transport label.""" self._registry[label] = generator def create( @@ -85,64 +92,46 @@ def create( consumers: list[Consumer] | None = None, interceptors: list[ClientCallInterceptor] | None = None, ) -> Client: - """Create a new `Client` for the provided `AgentCard`. - - Args: - card: An `AgentCard` defining the characteristics of the agent. - consumers: A list of `Consumer` methods to pass responses to. - interceptors: A list of interceptors to use for each request. These - are used for things like attaching credentials or http headers - to all outbound requests. - - Returns: - A `Client` object. - - Raises: - If there is no valid matching of the client configuration with the - server configuration, a `ValueError` is raised. - """ - # Determine preferential transport + """Create a new `Client` for the provided `AgentCard`.""" server_set = [card.preferred_transport or TransportProtocol.jsonrpc] if card.additional_interfaces: server_set.extend([x.transport for x in card.additional_interfaces]) client_set = self._config.supported_transports or [ TransportProtocol.jsonrpc ] - transport = None - # Two options, use the client ordering or the server ordering. + transport_protocol = None if self._config.use_client_preference: for x in client_set: if x in server_set: - transport = x + transport_protocol = x break else: for x in server_set: if x in client_set: - transport = x + transport_protocol = x break - if not transport: - raise ValueError('no compatible transports found.') - if transport not in self._registry: - raise ValueError(f'no client available for {transport}') + if not transport_protocol: + raise ValueError("no compatible transports found.") + if transport_protocol not in self._registry: + raise ValueError(f"no client available for {transport_protocol}") + all_consumers = self._consumers.copy() if consumers: all_consumers.extend(consumers) - return self._registry[transport]( - card, self._config, all_consumers, interceptors or [] + + transport = self._registry[transport_protocol]( + card, self._config, interceptors or [] + ) + + return BaseClient( + card, self._config, transport, all_consumers, interceptors or [] ) def minimal_agent_card( url: str, transports: list[str] | None = None ) -> AgentCard: - """Generates a minimal card to simplify bootstrapping client creation. - - This minimal card is not viable itself to interact with the remote agent. - Instead this is a short hand way to take a known url and transport option - and interact with the get card endpoint of the agent server to get the - correct agent card. This pattern is necessary for gRPC based card access - as typically these servers won't expose a well known path card. - """ + """Generates a minimal card to simplify bootstrapping client creation.""" if transports is None: transports = [] return AgentCard( @@ -157,8 +146,8 @@ def minimal_agent_card( capabilities=AgentCapabilities(), default_input_modes=[], default_output_modes=[], - description='', + description="", skills=[], - version='', - name='', + version="", + name="", ) diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py deleted file mode 100644 index 1ead88c8..00000000 --- a/src/a2a/client/grpc_client.py +++ /dev/null @@ -1,544 +0,0 @@ -import logging - -from collections.abc import AsyncGenerator, AsyncIterator - - -try: - import grpc -except ImportError as e: - raise ImportError( - 'A2AGrpcClient requires grpcio and grpcio-tools to be installed. ' - 'Install with: ' - "'pip install a2a-sdk[grpc]'" - ) from e - - -from a2a.client.client import ( - Client, - ClientCallContext, - ClientConfig, - ClientEvent, - Consumer, -) -from a2a.client.client_task_manager import ClientTaskManager -from a2a.client.errors import A2AClientInvalidStateError -from a2a.client.middleware import ClientCallInterceptor -from a2a.grpc import a2a_pb2, a2a_pb2_grpc -from a2a.types import ( - AgentCard, - GetTaskPushNotificationConfigParams, - Message, - MessageSendConfiguration, - MessageSendParams, - Task, - TaskArtifactUpdateEvent, - TaskIdParams, - TaskPushNotificationConfig, - TaskQueryParams, - TaskStatusUpdateEvent, -) -from a2a.utils import proto_utils -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.CLIENT) -class GrpcTransportClient: - """Transport specific details for interacting with an A2A agent via gRPC.""" - - def __init__( - self, - grpc_stub: a2a_pb2_grpc.A2AServiceStub, - agent_card: AgentCard | None, - ): - """Initializes the GrpcTransportClient. - - Requires an `AgentCard` and a grpc `A2AServiceStub`. - - Args: - grpc_stub: A grpc client stub. - agent_card: The agent card object. - """ - self.agent_card = agent_card - self.stub = grpc_stub - # If they don't provide an agent card, but do have a stub, lookup the - # card from the stub. - self._needs_extended_card = ( - agent_card.supports_authenticated_extended_card - if agent_card - else True - ) - - async def send_message( - self, - request: MessageSendParams, - *, - context: ClientCallContext | None = None, - ) -> Task | Message: - """Sends a non-streaming message request to the agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - context: The client call context. - - Returns: - A `Task` or `Message` object containing the agent's response. - """ - response = await self.stub.SendMessage( - a2a_pb2.SendMessageRequest( - request=proto_utils.ToProto.message(request.message), - configuration=proto_utils.ToProto.message_send_configuration( - request.configuration - ), - metadata=proto_utils.ToProto.metadata(request.metadata), - ) - ) - if response.task: - return proto_utils.FromProto.task(response.task) - return proto_utils.FromProto.message(response.msg) - - async def send_message_streaming( - self, - request: MessageSendParams, - *, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[ - Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent - ]: - """Sends a streaming message request to the agent and yields responses as they arrive. - - This method uses gRPC streams to receive a stream of updates from the - agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - context: The client call context. - - Yields: - `Message` or `Task` or `TaskStatusUpdateEvent` or - `TaskArtifactUpdateEvent` objects as they are received in the - stream. - """ - stream = self.stub.SendStreamingMessage( - a2a_pb2.SendMessageRequest( - request=proto_utils.ToProto.message(request.message), - configuration=proto_utils.ToProto.message_send_configuration( - request.configuration - ), - metadata=proto_utils.ToProto.metadata(request.metadata), - ) - ) - while True: - response = await stream.read() - if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue] - break - yield proto_utils.FromProto.stream_response(response) - - async def resubscribe( - self, request: TaskIdParams, *, context: ClientCallContext | None = None - ) -> AsyncGenerator[ - Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent - ]: - """Reconnects to get task updates. - - This method uses a unary server-side stream to receive updates. - - Args: - request: The `TaskIdParams` object containing the task information to reconnect to. - context: The client call context. - - Yields: - Task update events, which can be either a Task, Message, - TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientInvalidStateError: If the server returns an invalid response. - """ - stream = self.stub.TaskSubscription( - a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}') - ) - while True: - response = await stream.read() - if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue] - break - yield proto_utils.FromProto.stream_response(response) - - async def get_task( - self, - request: TaskQueryParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieves the current state and history of a specific task. - - Args: - request: The `TaskQueryParams` object specifying the task ID - context: The client call context. - - Returns: - A `Task` object containing the Task. - """ - task = await self.stub.GetTask( - a2a_pb2.GetTaskRequest(name=f'tasks/{request.id}') - ) - return proto_utils.FromProto.task(task) - - async def cancel_task( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Requests the agent to cancel a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - context: The client call context. - - Returns: - A `Task` object containing the updated Task - """ - task = await self.stub.CancelTask( - a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') - ) - return proto_utils.FromProto.task(task) - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the config. - """ - config = await self.stub.CreateTaskPushNotificationConfig( - a2a_pb2.CreateTaskPushNotificationConfigRequest( - parent='', - config_id='', - config=proto_utils.ToProto.task_push_notification_config( - request - ), - ) - ) - return proto_utils.FromProto.task_push_notification_config_request( - config - ) - - async def get_task_callback( - self, - request: TaskIdParams, # TODO: Update to a push id params - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieves the push notification configuration for a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the configuration. - """ - config = await self.stub.GetTaskPushNotificationConfig( - a2a_pb2.GetTaskPushNotificationConfigRequest( - name=f'tasks/{request.id}/pushNotification/undefined', - ) - ) - return proto_utils.FromProto.task_push_notification_config_request( - config - ) - - async def get_card( - self, - *, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieves the authenticated card (if necessary) or the public one. - - Args: - context: The client call context. - - Returns: - A `AgentCard` object containing the card. - - Raises: - grpc.RpcError: If a gRPC error occurs during the request. - """ - # If we don't have the public card, try to get that first. - card = self.agent_card - if card is None and not self._needs_extended_card: - raise ValueError('Agent card is not available.') - - if not self._needs_extended_card: - return card - - card_pb = await self.stub.GetAgentCard( - a2a_pb2.GetAgentCardRequest(), - ) - card = proto_utils.FromProto.agent_card(card_pb) - self.agent_card = card - self._needs_extended_card = False - return card - - -class GrpcClient(Client): - """GrpcClient provides the Client interface for the gRPC transport.""" - - def __init__( - self, - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], - ): - super().__init__(consumers, middleware) - if not config.grpc_channel_factory: - raise ValueError('GRPC client requires channel factory.') - self._card = card - self._config = config - channel = config.grpc_channel_factory(self._card.url) - stub = a2a_pb2_grpc.A2AServiceStub(channel) - self._transport_client = GrpcTransportClient(stub, self._card) - - async def send_message( - self, - request: Message, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent | Message]: - """Sends a message to the agent. - - This method handles both streaming and non-streaming (polling) interactions - based on the client configuration and agent capabilities. It will yield - events as they are received from the agent. - - Args: - request: The message to send to the agent. - context: The client call context. - - Yields: - An async iterator of `ClientEvent` or a final `Message` response. - """ - config = MessageSendConfiguration( - accepted_output_modes=self._config.accepted_output_modes, - blocking=not self._config.polling, - push_notification_config=( - self._config.push_notification_configs[0] - if self._config.push_notification_configs - else None - ), - ) - if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport_client.send_message( - MessageSendParams( - message=request, - configuration=config, - ), - context=context, - ) - result = ( - (response, None) if isinstance(response, Task) else response - ) - await self.consume(result, self._card) - yield result - return - tracker = ClientTaskManager() - stream = self._transport_client.send_message_streaming( - MessageSendParams( - message=request, - configuration=config, - ), - context=context, - ) - # Only the first event may be a Message. All others must be Task - # or TaskStatusUpdates. Separate this one out, which allows our core - # event processing logic to ignore that case. - # TODO(mikeas1): Reconcile with other transport logic. - first_event = await anext(stream) - if isinstance(first_event, Message): - yield first_event - return - yield await self._process_response(tracker, first_event) - async for result in stream: - yield await self._process_response(tracker, result) - - async def _process_response( - self, - tracker: ClientTaskManager, - event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent, - ) -> ClientEvent: - result = event.root.result - # Update task, check for errors, etc. - if isinstance(result, Message): - raise A2AClientInvalidStateError( - 'received a streamed Message from server after first response; this' - ' is not supported' - ) - await tracker.process(result) - result = ( - tracker.get_task_or_raise(), - None if isinstance(result, Task) else result, - ) - await self.consume(result, self._card) - return result - - async def get_task( - self, - request: TaskQueryParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieves the current state and history of a specific task. - - Args: - request: The `TaskQueryParams` object specifying the task ID. - context: The client call context. - - Returns: - A `Task` object representing the current state of the task. - """ - return await self._transport_client.get_task( - request, - context=context, - ) - - async def cancel_task( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Requests the agent to cancel a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - context: The client call context. - - Returns: - A `Task` object containing the updated task status. - """ - return await self._transport_client.cancel_task( - request, - context=context, - ) - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `TaskPushNotificationConfig` object with the new configuration. - context: The client call context. - - Returns: - The created or updated `TaskPushNotificationConfig` object. - """ - return await self._transport_client.set_task_callback( - request, - context=context, - ) - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigParams, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieves the push notification configuration for a specific task. - - Args: - request: The `GetTaskPushNotificationConfigParams` object specifying the task. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the configuration. - """ - return await self._transport_client.get_task_callback( - request, - context=context, - ) - - async def resubscribe( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: - """Resubscribes to a task's event stream. - - This is only available if both the client and server support streaming. - - Args: - request: The `TaskIdParams` object specifying the task ID to resubscribe to. - context: The client call context. - - Yields: - An async iterator of `Task` or `Message` events. - - Raises: - Exception: If streaming is not supported by the client or server. - """ - if not self._config.streaming or not self._card.capabilities.streaming: - raise NotImplementedError( - 'client and/or server do not support resubscription.' - ) - if not self._transport_client: - raise ValueError('Transport client is not initialized.') - if not hasattr(self._transport_client, 'resubscribe'): - # This can happen if the proto definitions are out of date or the method is missing - raise NotImplementedError( - 'Resubscribe is not implemented on the gRPC transport client.' - ) - # Note: works correctly for resubscription where the first event is the - # current Task state. - tracker = ClientTaskManager() - async for result in self._transport_client.resubscribe( - request, - context=context, - ): - yield await self._process_response(tracker, result) - - async def get_card( - self, - *, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieves the agent's card. - - This will fetch the authenticated card if necessary and update the - client's internal state with the new card. - - Args: - context: The client call context. - - Returns: - The `AgentCard` for the agent. - """ - card = await self._transport_client.get_card( - context=context, - ) - self._card = card - return card - - -def NewGrpcClient( # noqa: N802 - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], -) -> Client: - """Generator for the `GrpcClient` implementation.""" - return GrpcClient(card, config, consumers, middleware) diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py deleted file mode 100644 index 5bf23dc2..00000000 --- a/src/a2a/client/jsonrpc_client.py +++ /dev/null @@ -1,851 +0,0 @@ -import json -import logging - -from collections.abc import AsyncGenerator, AsyncIterator -from typing import Any -from uuid import uuid4 - -import httpx - -from httpx_sse import SSEError, aconnect_sse - -from a2a.client.card_resolver import A2ACardResolver -from a2a.client.client import ( - Client, - ClientConfig, - ClientEvent, - Consumer, -) -from a2a.client.client_task_manager import ClientTaskManager -from a2a.client.errors import ( - A2AClientHTTPError, - A2AClientInvalidStateError, - A2AClientJSONError, - A2AClientJSONRPCError, - A2AClientTimeoutError, -) -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.types import ( - AgentCard, - CancelTaskRequest, - CancelTaskResponse, - GetTaskPushNotificationConfigParams, - GetTaskPushNotificationConfigRequest, - GetTaskPushNotificationConfigResponse, - GetTaskRequest, - GetTaskResponse, - JSONRPCErrorResponse, - Message, - MessageSendConfiguration, - MessageSendParams, - SendMessageRequest, - SendMessageResponse, - SendStreamingMessageRequest, - SendStreamingMessageResponse, - SetTaskPushNotificationConfigRequest, - SetTaskPushNotificationConfigResponse, - Task, - TaskIdParams, - TaskPushNotificationConfig, - TaskQueryParams, - TaskResubscriptionRequest, -) -from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.CLIENT) -class JsonRpcTransportClient: - """A2A Client for interacting with an A2A agent.""" - - def __init__( - self, - httpx_client: httpx.AsyncClient, - agent_card: AgentCard | None = None, - url: str | None = None, - interceptors: list[ClientCallInterceptor] | None = None, - ): - """Initializes the A2AClient. - - Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. - - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. - url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. - interceptors: An optional list of client call interceptors to apply to requests. - - Raises: - ValueError: If neither `agent_card` nor `url` is provided. - """ - if agent_card: - self.url = agent_card.url - elif url: - self.url = url - else: - raise ValueError('Must provide either agent_card or url') - - self.httpx_client = httpx_client - self.agent_card = agent_card - self.interceptors = interceptors or [] - # Indicate if we have captured an extended card details so we can update - # on first call if needed. It is done this way so the caller can setup - # their auth credentials based on the public card and get the updated - # card. - self._needs_extended_card = ( - not agent_card.supports_authenticated_extended_card - if agent_card - else True - ) - - async def _apply_interceptors( - self, - method_name: str, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Applies all registered interceptors to the request.""" - final_http_kwargs = http_kwargs or {} - final_request_payload = request_payload - - for interceptor in self.interceptors: - ( - final_request_payload, - final_http_kwargs, - ) = await interceptor.intercept( - method_name, - final_request_payload, - final_http_kwargs, - self.agent_card, - context, - ) - return final_request_payload, final_http_kwargs - - @staticmethod - async def get_client_from_agent_card_url( - httpx_client: httpx.AsyncClient, - base_url: str, - agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, - http_kwargs: dict[str, Any] | None = None, - ) -> 'JsonRpcTransportClient': - """[deprecated] Fetches the public AgentCard and initializes an A2A client. - - This method will always fetch the public agent card. If an authenticated - or extended agent card is required, the A2ACardResolver should be used - directly to fetch the specific card, and then the A2AClient should be - instantiated with it. - - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - base_url: The base URL of the agent's host. - agent_card_path: The path to the agent card endpoint, relative to the base URL. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.get request when fetching the agent card. - - Returns: - An initialized `A2AClient` instance. - - Raises: - A2AClientHTTPError: If an HTTP error occurs fetching the agent card. - A2AClientJSONError: If the agent card response is invalid. - """ - agent_card: AgentCard = await A2ACardResolver( - httpx_client, base_url=base_url, agent_card_path=agent_card_path - ).get_agent_card( - http_kwargs=http_kwargs - ) # Fetches public card by default - return JsonRpcTransportClient( - httpx_client=httpx_client, agent_card=agent_card - ) - - async def send_message( - self, - request: SendMessageRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> SendMessageResponse: - """Sends a non-streaming message request to the agent. - - Args: - request: The `SendMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/send', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SendMessageResponse.model_validate(response_data) - - async def send_message_streaming( - self, - request: SendStreamingMessageRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[SendStreamingMessageResponse]: - """Sends a streaming message request to the agent and yields responses as they arrive. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `SendStreamingMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - `SendStreamingMessageResponse` objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/stream', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'POST', - self.url, - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - yield SendStreamingMessageResponse.model_validate( - json.loads(sse.data) - ) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def _send_request( - self, - rpc_request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Sends a non-streaming JSON-RPC request to the agent. - - Args: - rpc_request_payload: JSON RPC payload for sending the request. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - - Returns: - The JSON response payload as a dictionary. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON. - """ - try: - response = await self.httpx_client.post( - self.url, json=rpc_request_payload, **(http_kwargs or {}) - ) - response.raise_for_status() - return response.json() - except httpx.ReadTimeout as e: - raise A2AClientTimeoutError('Client Request timed out') from e - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def get_task( - self, - request: GetTaskRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> GetTaskResponse: - """Retrieves the current state and history of a specific task. - - Args: - request: The `GetTaskRequest` object specifying the task ID and history length. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `GetTaskResponse` object containing the Task or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskResponse.model_validate(response_data) - - async def cancel_task( - self, - request: CancelTaskRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> CancelTaskResponse: - """Requests the agent to cancel a specific task. - - Args: - request: The `CancelTaskRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `CancelTaskResponse` object containing the updated Task with canceled status or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/cancel', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return CancelTaskResponse.model_validate(response_data) - - async def set_task_callback( - self, - request: SetTaskPushNotificationConfigRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> SetTaskPushNotificationConfigResponse: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/set', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SetTaskPushNotificationConfigResponse.model_validate( - response_data - ) - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> GetTaskPushNotificationConfigResponse: - """Retrieves the push notification configuration for a specific task. - - Args: - request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskPushNotificationConfigResponse.model_validate( - response_data - ) - - async def resubscribe( - self, - request: TaskResubscriptionRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[SendStreamingMessageResponse]: - """Reconnects to get task updates. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `TaskResubscriptionRequest` object containing the task information to reconnect to. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - `SendStreamingMessageResponse` objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/resubscribe', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'POST', - self.url, - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - yield SendStreamingMessageResponse.model_validate_json( - sse.data - ) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def get_card( - self, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieves the authenticated card (if necessary) or the public one. - - Args: - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `AgentCard` object containing the card or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - # If we don't have the public card, try to get that first. - card = self.agent_card - if not card: - resolver = A2ACardResolver(self.httpx_client, self.url) - card = await resolver.get_agent_card(http_kwargs=http_kwargs) - self._needs_extended_card = ( - card.supports_authenticated_extended_card - ) - self.agent_card = card - - if not self._needs_extended_card: - return card - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'card/getAuthenticated', - {}, - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - card = AgentCard.model_validate(response_data) - self.agent_card = card - self._needs_extended_card = False - return card - - -@trace_class(kind=SpanKind.CLIENT) -class JsonRpcClient(Client): - """JsonRpcClient is the implementation of the JSONRPC A2A client. - - This client proxies requests to the JsonRpcTransportClient implementation - and manages the JSONRPC specific details. If passing additional arguments - in the http.post command, these should be attached to the ClientCallContext - under the dictionary key 'http_kwargs'. - """ - - def __init__( - self, - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], - ): - super().__init__(consumers, middleware) - if not config.httpx_client: - raise Exception('JsonRpc client requires httpx client.') - self._card = card - url = card.url - self._config = config - self._transport_client = JsonRpcTransportClient( - config.httpx_client, self._card, url, middleware - ) - - def get_http_args( - self, context: ClientCallContext | None - ) -> dict[str, Any] | None: - """Extract HTTP-specific keyword arguments from the client call context. - - Args: - context: The client call context. - - Returns: - A dictionary of HTTP arguments, or None. - """ - return context.state.get('http_kwargs', None) if context else None - - async def send_message( - self, - request: Message, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent | Message]: - """Send a message to the agent and consumes the response(s). - - This method handles both blocking (non-streaming) and streaming responses - based on the client configuration and agent capabilities. - - Args: - request: The message to send. - context: The client call context. - - Yields: - An async iterator of `ClientEvent` or a final `Message` response. - - Raises: - JSONRPCError: If the agent returns a JSON-RPC error in the response. - """ - config = MessageSendConfiguration( - accepted_output_modes=self._config.accepted_output_modes, - blocking=not self._config.polling, - push_notification_config=( - self._config.push_notification_configs[0] - if self._config.push_notification_configs - else None - ), - ) - if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport_client.send_message( - SendMessageRequest( - params=MessageSendParams( - message=request, - configuration=config, - ), - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - if isinstance(response.root, JSONRPCErrorResponse): - raise A2AClientJSONRPCError(response.root) - result = response.root.result - result = result if isinstance(result, Message) else (result, None) - await self.consume(result, self._card) - yield result - return - tracker = ClientTaskManager() - stream = self._transport_client.send_message_streaming( - SendStreamingMessageRequest( - params=MessageSendParams( - message=request, - configuration=config, - ), - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - # Only the first event may be a Message. All others must be Task - # or TaskStatusUpdates. Separate this one out, which allows our core - # event processing logic to ignore that case. - first_event = await anext(stream) - if isinstance(first_event, Message): - yield first_event - return - yield await self._process_response(tracker, first_event) - async for event in stream: - yield await self._process_response(tracker, event) - - async def _process_response( - self, - tracker: ClientTaskManager, - event: SendStreamingMessageResponse, - ) -> ClientEvent: - if isinstance(event.root, JSONRPCErrorResponse): - raise A2AClientJSONRPCError(event.root) - result = event.root.result - # Update task, check for errors, etc. - if isinstance(result, Message): - raise A2AClientInvalidStateError( - 'received a streamed Message from server after first response; this' - ' is not supported' - ) - await tracker.process(result) - result = ( - tracker.get_task_or_raise(), - None if isinstance(result, Task) else result, - ) - await self.consume(result, self._card) - return result - - async def get_task( - self, - request: TaskQueryParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieve a task from the agent. - - Args: - request: Parameters to identify the task. - context: The client call context. - - Returns: - The requested task. - """ - response = await self._transport_client.get_task( - GetTaskRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - return response.root.result - - async def cancel_task( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Cancel an ongoing task on the agent. - - Args: - request: Parameters to identify the task to cancel. - context: The client call context. - - Returns: - The task after the cancellation request. - """ - response = await self._transport_client.cancel_task( - CancelTaskRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - return response.root.result - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Set a push notification callback for a task. - - Args: - request: The push notification configuration to set. - context: The client call context. - - Returns: - The configured task push notification configuration. - """ - response = await self._transport_client.set_task_callback( - SetTaskPushNotificationConfigRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - return response.root.result - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigParams, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieve the push notification callback configuration for a task. - - Args: - request: Parameters to identify the task and configuration. - context: The client call context. - - Returns: - The requested task push notification configuration. - """ - response = await self._transport_client.get_task_callback( - GetTaskPushNotificationConfigRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - return response.root.result - - async def resubscribe( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: - """Resubscribe to a task's event stream. - - This is only available if both the client and server support streaming. - - Args: - request: Parameters to identify the task to resubscribe to. - context: The client call context. - - Yields: - Task events from the agent. - - Raises: - Exception: If streaming is not supported. - """ - if not self._config.streaming or not self._card.capabilities.streaming: - raise NotImplementedError( - 'client and/or server do not support resubscription.' - ) - tracker = ClientTaskManager() - async for event in self._transport_client.resubscribe( - TaskResubscriptionRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ): - yield await self._process_response(tracker, event) - - async def get_card( - self, - *, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieve the agent's card. - - This may involve fetching the public card first if not already available, - and then fetching the authenticated extended card if supported and required. - - Args: - context: The client call context. - - Returns: - The agent's card. - """ - return await self._transport_client.get_card( - http_kwargs=self.get_http_args(context), - context=context, - ) - - -def NewJsonRpcClient( # noqa: N802 - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], -) -> Client: - """Factory function for the `JsonRpcClient` implementation.""" - return JsonRpcClient(card, config, consumers, middleware) - - -# For backward compatability define this alias. This will be deprecated in -# a future release. -A2AClient = JsonRpcTransportClient diff --git a/src/a2a/client/rest_client.py b/src/a2a/client/rest_client.py deleted file mode 100644 index 552defc3..00000000 --- a/src/a2a/client/rest_client.py +++ /dev/null @@ -1,833 +0,0 @@ -import json -import logging - -from collections.abc import AsyncGenerator, AsyncIterator -from typing import Any - -import httpx - -from google.protobuf.json_format import MessageToDict, Parse, ParseDict -from httpx_sse import SSEError, aconnect_sse - -from a2a.client.card_resolver import A2ACardResolver -from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer -from a2a.client.client_task_manager import ClientTaskManager -from a2a.client.errors import ( - A2AClientHTTPError, - A2AClientInvalidStateError, - A2AClientJSONError, -) -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.grpc import a2a_pb2 -from a2a.types import ( - AgentCard, - GetTaskPushNotificationConfigParams, - Message, - MessageSendConfiguration, - MessageSendParams, - Task, - TaskArtifactUpdateEvent, - TaskIdParams, - TaskPushNotificationConfig, - TaskQueryParams, - TaskStatusUpdateEvent, -) -from a2a.utils import proto_utils -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.CLIENT) -class RestTransportClient: - """A2A Client for interacting with an A2A agent.""" - - def __init__( - self, - httpx_client: httpx.AsyncClient, - agent_card: AgentCard | None = None, - url: str | None = None, - interceptors: list[ClientCallInterceptor] | None = None, - ): - """Initializes the A2AClient. - - Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. - - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. - url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. - interceptors: An optional list of client call interceptors to apply to requests. - - Raises: - ValueError: If neither `agent_card` nor `url` is provided. - """ - if agent_card: - self.url = agent_card.url - elif url: - self.url = url - else: - raise ValueError('Must provide either agent_card or url') - # If the url ends in / remove it as this is added by the routes - if self.url.endswith('/'): - self.url = self.url[:-1] - self.httpx_client = httpx_client - self.agent_card = agent_card - self.interceptors = interceptors or [] - # Indicate if we have captured an extended card details so we can update - # on first call if needed. It is done this way so the caller can setup - # their auth credentials based on the public card and get the updated - # card. - self._needs_extended_card = ( - not agent_card.supports_authenticated_extended_card - if agent_card - else True - ) - - async def _apply_interceptors( - self, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Applies all registered interceptors to the request.""" - final_http_kwargs = http_kwargs or {} - final_request_payload = request_payload - # TODO: Implement interceptors for other transports - return final_request_payload, final_http_kwargs - - async def send_message( - self, - request: MessageSendParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> Task | Message: - """Sends a non-streaming message request to the agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `Task` or `Message` object containing the agent's response. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.SendMessageRequest( - request=proto_utils.ToProto.message(request.message), - configuration=proto_utils.ToProto.message_send_configuration( - request.configuration - ), - metadata=( - proto_utils.ToProto.metadata(request.metadata) - if request.metadata - else None - ), - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - response_data = await self._send_post_request( - '/v1/message:send', payload, modified_kwargs - ) - response_pb = a2a_pb2.SendMessageResponse() - ParseDict(response_data, response_pb) - return proto_utils.FromProto.task_or_message(response_pb) - - async def send_message_streaming( - self, - request: MessageSendParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[ - Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message - ]: - """Sends a streaming message request to the agent and yields responses as they arrive. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - Objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.SendMessageRequest( - request=proto_utils.ToProto.message(request.message), - configuration=proto_utils.ToProto.message_send_configuration( - request.configuration - ), - metadata=( - proto_utils.ToProto.metadata(request.metadata) - if request.metadata - else None - ), - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'POST', - f'{self.url}/v1/message:stream', - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - event = a2a_pb2.StreamResponse() - Parse(sse.data, event) - yield proto_utils.FromProto.stream_response(event) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def _send_post_request( - self, - target: str, - rpc_request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Sends a non-streaming JSON-RPC request to the agent. - - Args: - target: url path - rpc_request_payload: JSON payload for sending the request. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - - Returns: - The JSON response payload as a dictionary. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON. - """ - try: - response = await self.httpx_client.post( - f'{self.url}{target}', - json=rpc_request_payload, - **(http_kwargs or {}), - ) - response.raise_for_status() - return response.json() - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def _send_get_request( - self, - target: str, - query_params: dict[str, str], - http_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Sends a non-streaming JSON-RPC request to the agent. - - Args: - target: url path - query_params: HTTP query params for the request. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - - Returns: - The JSON response payload as a dictionary. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON. - """ - try: - response = await self.httpx_client.get( - f'{self.url}{target}', - params=query_params, - **(http_kwargs or {}), - ) - response.raise_for_status() - return response.json() - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def get_task( - self, - request: TaskQueryParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieves the current state and history of a specific task. - - Args: - request: The `TaskQueryParams` object specifying the task ID and history length. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `Task` object containing the Task. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - # Apply interceptors before sending - only for the http kwargs - payload, modified_kwargs = await self._apply_interceptors( - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_get_request( - f'/v1/tasks/{request.taskId}', - {'historyLength': str(request.history_length)} - if request.history_length - else {}, - modified_kwargs, - ) - task = a2a_pb2.Task() - ParseDict(response_data, task) - return proto_utils.FromProto.task(task) - - async def cancel_task( - self, - request: TaskIdParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> Task: - """Requests the agent to cancel a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `Task` object containing the updated Task with canceled status - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - response_data = await self._send_post_request( - f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs - ) - task = a2a_pb2.Task() - ParseDict(response_data, task) - return proto_utils.FromProto.task(task) - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the confirmation. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( - parent=f'tasks/{request.taskId}', - config_id=request.push_notification_config.id, - config=proto_utils.ToProto.task_push_notification_config(request), - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, http_kwargs, context - ) - response_data = await self._send_post_request( - f'/v1/tasks/{request.taskId}/pushNotificationConfigs/', - payload, - modified_kwargs, - ) - config = a2a_pb2.TaskPushNotificationConfig() - ParseDict(response_data, config) - return proto_utils.FromProto.task_push_notification_config(config) - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieves the push notification configuration for a specific task. - - Args: - request: The `GetTaskPushNotificationConfigParams` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the configuration. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.GetTaskPushNotificationConfigRequest( - name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - response_data = await self._send_get_request( - f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', - {}, - modified_kwargs, - ) - config = a2a_pb2.TaskPushNotificationConfig() - ParseDict(response_data, config) - return proto_utils.FromProto.task_push_notification_config(config) - - async def resubscribe( - self, - request: TaskIdParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[ - Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message - ]: - """Reconnects to get task updates. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `TaskIdParams` object containing the task information to reconnect to. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - Objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.TaskSubscriptionRequest( - name=f'tasks/{request.id}', - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'GET', - f'{self.url}/v1/tasks/{request.id}:subscribe', - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - event = a2a_pb2.StreamResponse() - Parse(sse.data, event) - yield proto_utils.FromProto.stream_response(event) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def get_card( - self, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieves the authenticated card (if necessary) or the public one. - - Args: - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `AgentCard` object containing the card or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - # If we don't have the public card, try to get that first. - card = self.agent_card - if not card: - resolver = A2ACardResolver(self.httpx_client, self.url) - card = await resolver.get_agent_card(http_kwargs=http_kwargs) - self._needs_extended_card = ( - card.supports_authenticated_extended_card - ) - self.agent_card = card - - if not self._needs_extended_card: - return card - - # Apply interceptors before sending - _, modified_kwargs = await self._apply_interceptors( - {}, - http_kwargs, - context, - ) - response_data = await self._send_get_request( - '/v1/card/get', {}, modified_kwargs - ) - card = AgentCard.model_validate(response_data) - self.agent_card = card - self._needs_extended_card = False - return card - - -@trace_class(kind=SpanKind.CLIENT) -class RestClient(Client): - """RestClient is the implementation of the RESTful A2A client. - - This client proxies requests to the RestTransportClient implementation - and manages the REST specific details. If passing additional arguments - in the http.post command, these should be attached to the ClientCallContext - under the dictionary key 'http_kwargs'. - """ - - def __init__( - self, - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], - ): - super().__init__(consumers, middleware) - if not config.httpx_client: - raise ValueError('RestClient client requires httpx client.') - self._card = card - url = card.url - self._config = config - self._transport_client = RestTransportClient( - config.httpx_client, self._card, url, middleware - ) - - def get_http_args( - self, context: ClientCallContext | None - ) -> dict[str, Any] | None: - """Extract HTTP-specific keyword arguments from the client call context. - - Args: - context: The client call context. - - Returns: - A dictionary of HTTP arguments, or None. - """ - return context.state.get('http_kwargs', None) if context else None - - async def send_message( - self, - request: Message, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[Message | ClientEvent]: - """Send a message to the agent and consumes the response(s). - - This method handles both blocking (non-streaming) and streaming responses - based on the client configuration and agent capabilities. - - Args: - request: The message to send. - context: The client call context. - - Yields: - The final message or task result from the agent. - """ - config = MessageSendConfiguration( - accepted_output_modes=self._config.accepted_output_modes, - blocking=not self._config.polling, - push_notification_config=( - self._config.push_notification_configs[0] - if self._config.push_notification_configs - else None - ), - ) - if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport_client.send_message( - MessageSendParams( - message=request, - configuration=config, - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - result = ( - response if isinstance(response, Message) else (response, None) - ) - await self.consume(result, self._card) - yield result - return - tracker = ClientTaskManager() - stream = self._transport_client.send_message_streaming( - MessageSendParams( - message=request, - configuration=config, - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - # Only the first event may be a Message. All others must be Task - # or TaskStatusUpdates. Separate this one out, which allows our core - # event processing logic to ignore that case. - first_event = await anext(stream) - if isinstance(first_event, Message): - yield first_event - return - yield await self._process_response(tracker, first_event) - async for event in stream: - yield await self._process_response(tracker, event) - - async def _process_response( - self, - tracker: ClientTaskManager, - event: Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message, - ) -> ClientEvent: - result = event.root.result - # Update task, check for errors, etc. - if isinstance(result, Message): - raise A2AClientInvalidStateError( - 'received a streamed Message from server after first response; this' - ' is not supported' - ) - await tracker.process(result) - result = ( - tracker.get_task_or_raise(), - None if isinstance(result, Task) else result, - ) - await self.consume(result, self._card) - return result - - async def get_task( - self, - request: TaskQueryParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieve a task from the agent. - - Args: - request: Parameters to identify the task. - context: The client call context. - - Returns: - The requested task. - """ - return await self._transport_client.get_task( - request, - http_kwargs=self.get_http_args(context), - context=context, - ) - - async def cancel_task( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Cancel an ongoing task on the agent. - - Args: - request: Parameters to identify the task to cancel. - context: The client call context. - - Returns: - The task after the cancellation request. - """ - return await self._transport_client.cancel_task( - request, - http_kwargs=self.get_http_args(context), - context=context, - ) - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Set a push notification callback for a task. - - Args: - request: The push notification configuration to set. - context: The client call context. - - Returns: - The configured task push notification configuration. - """ - return await self._transport_client.set_task_callback( - request, - http_kwargs=self.get_http_args(context), - context=context, - ) - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigParams, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieve the push notification callback configuration for a task. - - Args: - request: Parameters to identify the task and configuration. - context: The client call context. - - Returns: - The requested task push notification configuration. - """ - return await self._transport_client.get_task_callback( - request, - http_kwargs=self.get_http_args(context), - context=context, - ) - - async def resubscribe( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: - """Resubscribe to a task's event stream. - - This is only available if both the client and server support streaming. - - Args: - request: Parameters to identify the task to resubscribe to. - context: The client call context. - - Yields: - Task events from the agent. - - Raises: - Exception: If streaming is not supported. - """ - if not self._config.streaming or not self._card.capabilities.streaming: - raise NotImplementedError( - 'client and/or server do not support resubscription.' - ) - tracker = ClientTaskManager() - async for event in self._transport_client.resubscribe( - request, - http_kwargs=self.get_http_args(context), - context=context, - ): - # Update task, check for errors, etc. - yield await self._process_response(tracker, event) - - async def get_card( - self, - *, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieve the agent's card. - - This may involve fetching the public card first if not already available, - and then fetching the authenticated extended card if supported and required. - - Args: - context: The client call context. - - Returns: - The agent's card. - """ - return await self._transport_client.get_card( - http_kwargs=self.get_http_args(context), - context=context, - ) - - -def NewRestfulClient( # noqa: N802 - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], -) -> Client: - """Factory function for the `RestClient` implementation.""" - return RestClient(card, config, consumers, middleware) diff --git a/src/a2a/client/transports/__init__.py b/src/a2a/client/transports/__init__.py new file mode 100644 index 00000000..8bcca4e4 --- /dev/null +++ b/src/a2a/client/transports/__init__.py @@ -0,0 +1,19 @@ +"""A2A Client Transports.""" + +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.client.transports.rest import RestTransport + + +try: + from a2a.client.transports.grpc import GrpcTransport +except ImportError: + GrpcTransport = None + + +__all__ = [ + "ClientTransport", + "GrpcTransport", + "JsonRpcTransport", + "RestTransport", +] diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py new file mode 100644 index 00000000..23bf1c56 --- /dev/null +++ b/src/a2a/client/transports/base.py @@ -0,0 +1,101 @@ +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator + +from a2a.client.middleware import ClientCallContext +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) + + +class ClientTransport(ABC): + """Abstract base class for a client transport.""" + + @abstractmethod + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + + @abstractmethod + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + yield + + @abstractmethod + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + + @abstractmethod + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + + @abstractmethod + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + + @abstractmethod + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + + @abstractmethod + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Reconnects to get task updates.""" + yield + + @abstractmethod + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + + @abstractmethod + async def close(self) -> None: + """Closes the transport.""" diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py new file mode 100644 index 00000000..fcba6f27 --- /dev/null +++ b/src/a2a/client/transports/grpc.py @@ -0,0 +1,190 @@ +import logging + +from collections.abc import AsyncGenerator + + +try: + import grpc +except ImportError as e: + raise ImportError( + "A2AGrpcClient requires grpcio and grpcio-tools to be installed. " + "Install with: " + "'pip install a2a-sdk[grpc]'" + ) from e + +from a2a.client.middleware import ClientCallContext +from a2a.client.transports.base import ClientTransport +from a2a.grpc import a2a_pb2, a2a_pb2_grpc +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) +from a2a.utils import proto_utils +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class GrpcTransport(ClientTransport): + """A gRPC transport for the A2A client.""" + + def __init__( + self, + grpc_stub: a2a_pb2_grpc.A2AServiceStub, + agent_card: AgentCard | None, + ): + """Initializes the GrpcTransport.""" + self.agent_card = agent_card + self.stub = grpc_stub + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card if agent_card else True + ) + + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + response = await self.stub.SendMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + if response.task: + return proto_utils.FromProto.task(response.task) + return proto_utils.FromProto.message(response.msg) + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + stream = self.stub.SendStreamingMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: + break + yield proto_utils.FromProto.stream_response(response) + + async def resubscribe( + self, request: TaskIdParams, *, context: ClientCallContext | None = None + ) -> AsyncGenerator[ + Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Reconnects to get task updates.""" + stream = self.stub.TaskSubscription( + a2a_pb2.TaskSubscriptionRequest(name=f"tasks/{request.id}") + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: + break + yield proto_utils.FromProto.stream_response(response) + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + task = await self.stub.GetTask( + a2a_pb2.GetTaskRequest(name=f"tasks/{request.id}") + ) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + task = await self.stub.CancelTask( + a2a_pb2.CancelTaskRequest(name=f"tasks/{request.id}") + ) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + config = await self.stub.CreateTaskPushNotificationConfig( + a2a_pb2.CreateTaskPushNotificationConfigRequest( + parent="", + config_id="", + config=proto_utils.ToProto.task_push_notification_config(request), + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + config = await self.stub.GetTaskPushNotificationConfig( + a2a_pb2.GetTaskPushNotificationConfigRequest( + name=f"tasks/{request.id}/pushNotification/{request.push_notification_config_id}", + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + card = self.agent_card + if card is None and not self._needs_extended_card: + raise ValueError("Agent card is not available.") + + if not self._needs_extended_card: + return card + + card_pb = await self.stub.GetAgentCard( + a2a_pb2.GetAgentCardRequest(), + ) + card = proto_utils.FromProto.agent_card(card_pb) + self.agent_card = card + self._needs_extended_card = False + return card + + async def close(self) -> None: + """Closes the gRPC channel.""" + if hasattr(self.stub, "close"): + await self.stub.close() diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py new file mode 100644 index 00000000..26f92749 --- /dev/null +++ b/src/a2a/client/transports/jsonrpc.py @@ -0,0 +1,337 @@ +import json +import logging + +from collections.abc import AsyncGenerator +from typing import Any +from uuid import uuid4 + +import httpx + +from httpx_sse import SSEError, aconnect_sse + +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, + A2AClientJSONRPCError, + A2AClientTimeoutError, +) +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.types import ( + AgentCard, + CancelTaskRequest, + GetTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + JSONRPCErrorResponse, + Message, + MessageSendParams, + SendMessageRequest, + SendStreamingMessageRequest, + SetTaskPushNotificationConfigRequest, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskResubscriptionRequest, + TaskStatusUpdateEvent, +) +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class JsonRpcTransport(ClientTransport): + """A JSON-RPC transport for the A2A client.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the JsonRpcTransport.""" + if agent_card: + self.url = agent_card.url + elif url: + self.url = url + else: + raise ValueError("Must provide either agent_card or url") + + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card if agent_card else True + ) + + async def _apply_interceptors( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + + for interceptor in self.interceptors: + ( + final_request_payload, + final_http_kwargs, + ) = await interceptor.intercept( + method_name, + final_request_payload, + final_http_kwargs, + self.agent_card, + context, + ) + return final_request_payload, final_http_kwargs + + def _get_http_args(self, context: ClientCallContext | None) -> dict[str, Any] | None: + return context.state.get("http_kwargs") if context else None + + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + rpc_request = SendMessageRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + "message/send", + rpc_request.model_dump(mode="json", exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = SendMessageRequest.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + rpc_request = SendStreamingMessageRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + "message/stream", + rpc_request.model_dump(mode="json", exclude_none=True), + self._get_http_args(context), + context, + ) + + modified_kwargs.setdefault("timeout", None) + + async with aconnect_sse( + self.httpx_client, + "POST", + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + response = SendStreamingMessageRequest.model_validate( + json.loads(sse.data) + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + yield response.root.result + except SSEError as e: + raise A2AClientHTTPError( + 400, f"Invalid SSE response or protocol error: {e}" + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f"Network communication error: {e}" + ) from e + + async def _send_request( + self, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + try: + response = await self.httpx_client.post( + self.url, json=rpc_request_payload, **(http_kwargs or {}) + ) + response.raise_for_status() + return response.json() + except httpx.ReadTimeout as e: + raise A2AClientTimeoutError("Client Request timed out") from e + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError(503, f"Network communication error: {e}") from e + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + rpc_request = GetTaskRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + "tasks/get", + rpc_request.model_dump(mode="json", exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = GetTaskRequest.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + rpc_request = CancelTaskRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + "tasks/cancel", + rpc_request.model_dump(mode="json", exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = CancelTaskRequest.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + rpc_request = SetTaskPushNotificationConfigRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + "tasks/pushNotificationConfig/set", + rpc_request.model_dump(mode="json", exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = SetTaskPushNotificationConfigRequest.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + rpc_request = GetTaskPushNotificationConfigRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + "tasks/pushNotificationConfig/get", + rpc_request.model_dump(mode="json", exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = GetTaskPushNotificationConfigRequest.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Reconnects to get task updates.""" + rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + "tasks/resubscribe", + rpc_request.model_dump(mode="json", exclude_none=True), + self._get_http_args(context), + context, + ) + + modified_kwargs.setdefault("timeout", None) + + async with aconnect_sse( + self.httpx_client, + "POST", + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + response = SendStreamingMessageRequest.model_validate_json(sse.data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + yield response.root.result + except SSEError as e: + raise A2AClientHTTPError( + 400, f"Invalid SSE response or protocol error: {e}" + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f"Network communication error: {e}" + ) from e + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + card = self.agent_card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = await resolver.get_agent_card(http_kwargs=self._get_http_args(context)) + self._needs_extended_card = card.supports_authenticated_extended_card + self.agent_card = card + + if not self._needs_extended_card: + return card + + payload, modified_kwargs = await self._apply_interceptors( + "card/getAuthenticated", + {}, + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + card = AgentCard.model_validate(response_data) + self.agent_card = card + self._needs_extended_card = False + return card + + async def close(self) -> None: + """Closes the httpx client.""" + await self.httpx_client.aclose() diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py new file mode 100644 index 00000000..7c0ef7ab --- /dev/null +++ b/src/a2a/client/transports/rest.py @@ -0,0 +1,383 @@ +import json +import logging + +from collections.abc import AsyncGenerator +from typing import Any + +import httpx + +from google.protobuf.json_format import MessageToDict, Parse, ParseDict +from httpx_sse import SSEError, aconnect_sse + +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.grpc import a2a_pb2 +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) +from a2a.utils import proto_utils +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class RestTransport(ClientTransport): + """A REST transport for the A2A client.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the RestTransport.""" + if agent_card: + self.url = agent_card.url + elif url: + self.url = url + else: + raise ValueError('Must provide either agent_card or url') + if self.url.endswith('/'): + self.url = self.url[:-1] + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card + if agent_card + else True + ) + + async def _apply_interceptors( + self, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + # TODO: Implement interceptors for other transports + return final_request_payload, final_http_kwargs + + def _get_http_args( + self, context: ClientCallContext | None + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs') if context else None + + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + pb = a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=( + proto_utils.ToProto.metadata(request.metadata) + if request.metadata + else None + ), + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + response_data = await self._send_post_request( + '/v1/message:send', payload, modified_kwargs + ) + response_pb = a2a_pb2.SendMessageResponse() + ParseDict(response_data, response_pb) + return proto_utils.FromProto.task_or_message(response_pb) + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + pb = a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=( + proto_utils.ToProto.metadata(request.metadata) + if request.metadata + else None + ), + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + f'{self.url}/v1/message:stream', + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_post_request( + self, + target: str, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + try: + response = await self.httpx_client.post( + f'{self.url}{target}', + json=rpc_request_payload, + **(http_kwargs or {}), + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_get_request( + self, + target: str, + query_params: dict[str, str], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + try: + response = await self.httpx_client.get( + f'{self.url}{target}', + params=query_params, + **(http_kwargs or {}), + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + payload, modified_kwargs = await self._apply_interceptors( + request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.id}', + {'historyLength': str(request.history_length)} + if request.history_length + else {}, + modified_kwargs, + ) + task = a2a_pb2.Task() + ParseDict(response_data, task) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs + ) + task = a2a_pb2.Task() + ParseDict(response_data, task) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( + parent=f'tasks/{request.task_id}', + config_id=request.push_notification_config.id, + config=proto_utils.ToProto.task_push_notification_config(request), + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, self._get_http_args(context), context + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.task_id}/pushNotificationConfigs/', + payload, + modified_kwargs, + ) + config = a2a_pb2.TaskPushNotificationConfig() + ParseDict(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + pb = a2a_pb2.GetTaskPushNotificationConfigRequest( + name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + {}, + modified_kwargs, + ) + config = a2a_pb2.TaskPushNotificationConfig() + ParseDict(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: + """Reconnects to get task updates.""" + pb = a2a_pb2.TaskSubscriptionRequest( + name=f'tasks/{request.id}', + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'GET', + f'{self.url}/v1/tasks/{request.id}:subscribe', + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + card = self.agent_card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = await resolver.get_agent_card( + http_kwargs=self._get_http_args(context) + ) + self._needs_extended_card = ( + card.supports_authenticated_extended_card + ) + self.agent_card = card + + if not self._needs_extended_card: + return card + + _, modified_kwargs = await self._apply_interceptors( + {}, + self._get_http_args(context), + context, + ) + response_data = await self._send_get_request( + '/v1/card/get', {}, modified_kwargs + ) + card = AgentCard.model_validate(response_data) + self.agent_card = card + self._needs_extended_card = False + return card + + async def close(self) -> None: + """Closes the httpx client.""" + await self.httpx_client.aclose() From 0289cb7ea3f89908e26fa59c20c5220ff0fb911c Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 29 Jul 2025 20:07:51 +0000 Subject: [PATCH 02/12] Gemini authored: update tests --- src/a2a/client/base_client.py | 18 +- src/a2a/client/client.py | 4 + src/a2a/client/client_factory.py | 10 +- src/a2a/client/transports/base.py | 2 + src/a2a/client/transports/grpc.py | 33 +- src/a2a/client/transports/jsonrpc.py | 110 ++-- tests/client/test_auth_middleware.py | 89 +-- tests/client/test_grpc_client.py | 22 +- tests/client/test_jsonrpc_client.py | 923 +++------------------------ 9 files changed, 253 insertions(+), 958 deletions(-) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 01600366..f4a8d03d 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -73,7 +73,9 @@ async def send_message( params = MessageSendParams(message=request, configuration=config) if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport.send_message(params, context=context) + response = await self._transport.send_message( + params, context=context + ) result = ( (response, None) if isinstance(response, Task) else response ) @@ -85,6 +87,9 @@ async def send_message( stream = self._transport.send_message_streaming(params, context=context) first_event = await anext(stream) + # The response from a server may be either exactly one Message or a + # series of Task updates. Separate out the first message for special + # case handling, which allows us to simplify further stream processing. if isinstance(first_event, Message): await self.consume(first_event, self._card) yield first_event @@ -102,7 +107,7 @@ async def _process_response( ) -> ClientEvent: if isinstance(event, Message): raise A2AClientInvalidStateError( - "received a streamed Message from server after first response; this is not supported" + 'received a streamed Message from server after first response; this is not supported' ) await tracker.process(event) task = tracker.get_task_or_raise() @@ -201,11 +206,16 @@ async def resubscribe( """ if not self._config.streaming or not self._card.capabilities.streaming: raise NotImplementedError( - "client and/or server do not support resubscription." + 'client and/or server do not support resubscription.' ) tracker = ClientTaskManager() - async for event in self._transport.resubscribe(request, context=context): + # Note: resubscribe can only be called on an existing task. As such, + # we should never see Message updates, despite the typing of the service + # definition indicating it may be possible. + async for event in self._transport.resubscribe( + request, context=context + ): yield await self._process_response(tracker, event) async def get_card( diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index c51597c0..7cc10423 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -119,6 +119,8 @@ async def send_message( pairs, or a `Message`. Client will also send these values to any configured `Consumer`s in the client. """ + return + yield @abstractmethod async def get_task( @@ -164,6 +166,8 @@ async def resubscribe( context: ClientCallContext | None = None, ) -> AsyncIterator[ClientEvent]: """Resubscribes to a task's event stream.""" + return + yield @abstractmethod async def get_card( diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index a97f62d2..1e312793 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -111,9 +111,9 @@ def create( transport_protocol = x break if not transport_protocol: - raise ValueError("no compatible transports found.") + raise ValueError('no compatible transports found.') if transport_protocol not in self._registry: - raise ValueError(f"no client available for {transport_protocol}") + raise ValueError(f'no client available for {transport_protocol}') all_consumers = self._consumers.copy() if consumers: @@ -146,8 +146,8 @@ def minimal_agent_card( capabilities=AgentCapabilities(), default_input_modes=[], default_output_modes=[], - description="", + description='', skills=[], - version="", - name="", + version='', + name='', ) diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index 23bf1c56..ad693f24 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -38,6 +38,7 @@ async def send_message_streaming( Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: """Sends a streaming message request to the agent and yields responses as they arrive.""" + return yield @abstractmethod @@ -86,6 +87,7 @@ async def resubscribe( Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: """Reconnects to get task updates.""" + return yield @abstractmethod diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index fcba6f27..e75146bc 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -7,8 +7,8 @@ import grpc except ImportError as e: raise ImportError( - "A2AGrpcClient requires grpcio and grpcio-tools to be installed. " - "Install with: " + 'A2AGrpcClient requires grpcio and grpcio-tools to be installed. ' + 'Install with: ' "'pip install a2a-sdk[grpc]'" ) from e @@ -47,7 +47,9 @@ def __init__( self.agent_card = agent_card self.stub = grpc_stub self._needs_extended_card = ( - agent_card.supports_authenticated_extended_card if agent_card else True + agent_card.supports_authenticated_extended_card + if agent_card + else True ) async def send_message( @@ -101,7 +103,7 @@ async def resubscribe( ]: """Reconnects to get task updates.""" stream = self.stub.TaskSubscription( - a2a_pb2.TaskSubscriptionRequest(name=f"tasks/{request.id}") + a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}') ) while True: response = await stream.read() @@ -117,7 +119,7 @@ async def get_task( ) -> Task: """Retrieves the current state and history of a specific task.""" task = await self.stub.GetTask( - a2a_pb2.GetTaskRequest(name=f"tasks/{request.id}") + a2a_pb2.GetTaskRequest(name=f'tasks/{request.id}') ) return proto_utils.FromProto.task(task) @@ -129,7 +131,7 @@ async def cancel_task( ) -> Task: """Requests the agent to cancel a specific task.""" task = await self.stub.CancelTask( - a2a_pb2.CancelTaskRequest(name=f"tasks/{request.id}") + a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') ) return proto_utils.FromProto.task(task) @@ -142,9 +144,11 @@ async def set_task_callback( """Sets or updates the push notification configuration for a specific task.""" config = await self.stub.CreateTaskPushNotificationConfig( a2a_pb2.CreateTaskPushNotificationConfigRequest( - parent="", - config_id="", - config=proto_utils.ToProto.task_push_notification_config(request), + parent='', + config_id='', + config=proto_utils.ToProto.task_push_notification_config( + request + ), ) ) return proto_utils.FromProto.task_push_notification_config(config) @@ -158,7 +162,7 @@ async def get_task_callback( """Retrieves the push notification configuration for a specific task.""" config = await self.stub.GetTaskPushNotificationConfig( a2a_pb2.GetTaskPushNotificationConfigRequest( - name=f"tasks/{request.id}/pushNotification/{request.push_notification_config_id}", + name=f'tasks/{request.id}/pushNotification/{request.push_notification_config_id}', ) ) return proto_utils.FromProto.task_push_notification_config(config) @@ -170,11 +174,10 @@ async def get_card( ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card - if card is None and not self._needs_extended_card: - raise ValueError("Agent card is not available.") - - if not self._needs_extended_card: + if card and not self._needs_extended_card: return card + if card is None and not self._needs_extended_card: + raise ValueError('Agent card is not available.') card_pb = await self.stub.GetAgentCard( a2a_pb2.GetAgentCardRequest(), @@ -186,5 +189,5 @@ async def get_card( async def close(self) -> None: """Closes the gRPC channel.""" - if hasattr(self.stub, "close"): + if hasattr(self.stub, 'close'): await self.stub.close() diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 26f92749..be8bca5e 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -21,15 +21,21 @@ from a2a.types import ( AgentCard, CancelTaskRequest, + CancelTaskResponse, GetTaskPushNotificationConfigParams, GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, GetTaskRequest, + GetTaskResponse, JSONRPCErrorResponse, Message, MessageSendParams, SendMessageRequest, + SendMessageResponse, SendStreamingMessageRequest, + SendStreamingMessageResponse, SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, Task, TaskArtifactUpdateEvent, TaskIdParams, @@ -61,13 +67,15 @@ def __init__( elif url: self.url = url else: - raise ValueError("Must provide either agent_card or url") + raise ValueError('Must provide either agent_card or url') self.httpx_client = httpx_client self.agent_card = agent_card self.interceptors = interceptors or [] self._needs_extended_card = ( - agent_card.supports_authenticated_extended_card if agent_card else True + agent_card.supports_authenticated_extended_card + if agent_card + else True ) async def _apply_interceptors( @@ -93,8 +101,10 @@ async def _apply_interceptors( ) return final_request_payload, final_http_kwargs - def _get_http_args(self, context: ClientCallContext | None) -> dict[str, Any] | None: - return context.state.get("http_kwargs") if context else None + def _get_http_args( + self, context: ClientCallContext | None + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs') if context else None async def send_message( self, @@ -105,13 +115,13 @@ async def send_message( """Sends a non-streaming message request to the agent.""" rpc_request = SendMessageRequest(params=request, id=str(uuid4())) payload, modified_kwargs = await self._apply_interceptors( - "message/send", - rpc_request.model_dump(mode="json", exclude_none=True), + 'message/send', + rpc_request.model_dump(mode='json', exclude_none=True), self._get_http_args(context), context, ) response_data = await self._send_request(payload, modified_kwargs) - response = SendMessageRequest.model_validate(response_data) + response = SendMessageResponse.model_validate(response_data) if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) return response.root.result @@ -125,26 +135,28 @@ async def send_message_streaming( Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: """Sends a streaming message request to the agent and yields responses as they arrive.""" - rpc_request = SendStreamingMessageRequest(params=request, id=str(uuid4())) + rpc_request = SendStreamingMessageRequest( + params=request, id=str(uuid4()) + ) payload, modified_kwargs = await self._apply_interceptors( - "message/stream", - rpc_request.model_dump(mode="json", exclude_none=True), + 'message/stream', + rpc_request.model_dump(mode='json', exclude_none=True), self._get_http_args(context), context, ) - modified_kwargs.setdefault("timeout", None) + modified_kwargs.setdefault('timeout', None) async with aconnect_sse( self.httpx_client, - "POST", + 'POST', self.url, json=payload, **modified_kwargs, ) as event_source: try: async for sse in event_source.aiter_sse(): - response = SendStreamingMessageRequest.model_validate( + response = SendStreamingMessageResponse.model_validate( json.loads(sse.data) ) if isinstance(response.root, JSONRPCErrorResponse): @@ -152,13 +164,13 @@ async def send_message_streaming( yield response.root.result except SSEError as e: raise A2AClientHTTPError( - 400, f"Invalid SSE response or protocol error: {e}" + 400, f'Invalid SSE response or protocol error: {e}' ) from e except json.JSONDecodeError as e: raise A2AClientJSONError(str(e)) from e except httpx.RequestError as e: raise A2AClientHTTPError( - 503, f"Network communication error: {e}" + 503, f'Network communication error: {e}' ) from e async def _send_request( @@ -173,13 +185,15 @@ async def _send_request( response.raise_for_status() return response.json() except httpx.ReadTimeout as e: - raise A2AClientTimeoutError("Client Request timed out") from e + raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except json.JSONDecodeError as e: raise A2AClientJSONError(str(e)) from e except httpx.RequestError as e: - raise A2AClientHTTPError(503, f"Network communication error: {e}") from e + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e async def get_task( self, @@ -190,13 +204,13 @@ async def get_task( """Retrieves the current state and history of a specific task.""" rpc_request = GetTaskRequest(params=request, id=str(uuid4())) payload, modified_kwargs = await self._apply_interceptors( - "tasks/get", - rpc_request.model_dump(mode="json", exclude_none=True), + 'tasks/get', + rpc_request.model_dump(mode='json', exclude_none=True), self._get_http_args(context), context, ) response_data = await self._send_request(payload, modified_kwargs) - response = GetTaskRequest.model_validate(response_data) + response = GetTaskResponse.model_validate(response_data) if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) return response.root.result @@ -210,13 +224,13 @@ async def cancel_task( """Requests the agent to cancel a specific task.""" rpc_request = CancelTaskRequest(params=request, id=str(uuid4())) payload, modified_kwargs = await self._apply_interceptors( - "tasks/cancel", - rpc_request.model_dump(mode="json", exclude_none=True), + 'tasks/cancel', + rpc_request.model_dump(mode='json', exclude_none=True), self._get_http_args(context), context, ) response_data = await self._send_request(payload, modified_kwargs) - response = CancelTaskRequest.model_validate(response_data) + response = CancelTaskResponse.model_validate(response_data) if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) return response.root.result @@ -228,15 +242,19 @@ async def set_task_callback( context: ClientCallContext | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" - rpc_request = SetTaskPushNotificationConfigRequest(params=request, id=str(uuid4())) + rpc_request = SetTaskPushNotificationConfigRequest( + params=request, id=str(uuid4()) + ) payload, modified_kwargs = await self._apply_interceptors( - "tasks/pushNotificationConfig/set", - rpc_request.model_dump(mode="json", exclude_none=True), + 'tasks/pushNotificationConfig/set', + rpc_request.model_dump(mode='json', exclude_none=True), self._get_http_args(context), context, ) response_data = await self._send_request(payload, modified_kwargs) - response = SetTaskPushNotificationConfigRequest.model_validate(response_data) + response = SetTaskPushNotificationConfigResponse.model_validate( + response_data + ) if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) return response.root.result @@ -248,15 +266,19 @@ async def get_task_callback( context: ClientCallContext | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" - rpc_request = GetTaskPushNotificationConfigRequest(params=request, id=str(uuid4())) + rpc_request = GetTaskPushNotificationConfigRequest( + params=request, id=str(uuid4()) + ) payload, modified_kwargs = await self._apply_interceptors( - "tasks/pushNotificationConfig/get", - rpc_request.model_dump(mode="json", exclude_none=True), + 'tasks/pushNotificationConfig/get', + rpc_request.model_dump(mode='json', exclude_none=True), self._get_http_args(context), context, ) response_data = await self._send_request(payload, modified_kwargs) - response = GetTaskPushNotificationConfigRequest.model_validate(response_data) + response = GetTaskPushNotificationConfigResponse.model_validate( + response_data + ) if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) return response.root.result @@ -272,36 +294,38 @@ async def resubscribe( """Reconnects to get task updates.""" rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4())) payload, modified_kwargs = await self._apply_interceptors( - "tasks/resubscribe", - rpc_request.model_dump(mode="json", exclude_none=True), + 'tasks/resubscribe', + rpc_request.model_dump(mode='json', exclude_none=True), self._get_http_args(context), context, ) - modified_kwargs.setdefault("timeout", None) + modified_kwargs.setdefault('timeout', None) async with aconnect_sse( self.httpx_client, - "POST", + 'POST', self.url, json=payload, **modified_kwargs, ) as event_source: try: async for sse in event_source.aiter_sse(): - response = SendStreamingMessageRequest.model_validate_json(sse.data) + response = SendStreamingMessageResponse.model_validate_json( + sse.data + ) if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) yield response.root.result except SSEError as e: raise A2AClientHTTPError( - 400, f"Invalid SSE response or protocol error: {e}" + 400, f'Invalid SSE response or protocol error: {e}' ) from e except json.JSONDecodeError as e: raise A2AClientJSONError(str(e)) from e except httpx.RequestError as e: raise A2AClientHTTPError( - 503, f"Network communication error: {e}" + 503, f'Network communication error: {e}' ) from e async def get_card( @@ -313,15 +337,19 @@ async def get_card( card = self.agent_card if not card: resolver = A2ACardResolver(self.httpx_client, self.url) - card = await resolver.get_agent_card(http_kwargs=self._get_http_args(context)) - self._needs_extended_card = card.supports_authenticated_extended_card + card = await resolver.get_agent_card( + http_kwargs=self._get_http_args(context) + ) + self._needs_extended_card = ( + card.supports_authenticated_extended_card + ) self.agent_card = card if not self._needs_extended_card: return card payload, modified_kwargs = await self._apply_interceptors( - "card/getAuthenticated", + 'agent/getAuthenticatedExtendedCard', {}, self._get_http_args(context), context, diff --git a/tests/client/test_auth_middleware.py b/tests/client/test_auth_middleware.py index 55fb5b8b..1c37992f 100644 --- a/tests/client/test_auth_middleware.py +++ b/tests/client/test_auth_middleware.py @@ -1,3 +1,4 @@ +import json from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -6,25 +7,30 @@ import pytest import respx -from a2a.client import A2AClient -from a2a.client.auth import AuthInterceptor, InMemoryContextCredentialStore -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client import ( + AuthInterceptor, + Client, + ClientCallContext, + ClientCallInterceptor, + ClientConfig, + ClientFactory, + InMemoryContextCredentialStore, +) from a2a.types import ( - APIKeySecurityScheme, AgentCapabilities, AgentCard, + APIKeySecurityScheme, AuthorizationCodeOAuthFlow, HTTPAuthSecurityScheme, In, Message, - MessageSendParams, OAuth2SecurityScheme, OAuthFlows, OpenIdConnectSecurityScheme, Role, SecurityScheme, - SendMessageRequest, SendMessageSuccessResponse, + TransportProtocol, ) @@ -49,10 +55,11 @@ async def intercept( return request_payload, http_kwargs -def build_success_response() -> dict: - """Creates a valid JSON-RPC success response as dict.""" - return SendMessageSuccessResponse( - id='1', +def build_success_response(request: httpx.Request) -> httpx.Response: + """Creates a valid JSON-RPC success response based on the request.""" + request_payload = json.loads(request.content) + response_payload = SendMessageSuccessResponse( + id=request_payload['id'], jsonrpc='2.0', result=Message( kind='message', @@ -61,41 +68,33 @@ def build_success_response() -> dict: parts=[], ), ).model_dump(mode='json') + return httpx.Response(200, json=response_payload) -def build_send_message_request() -> SendMessageRequest: - """Builds a minimal SendMessageRequest.""" - return SendMessageRequest( - id='1', - params=MessageSendParams( - message=Message( - message_id='msg1', - role=Role.user, - parts=[], - ) - ), +def build_message() -> Message: + """Builds a minimal Message.""" + return Message( + message_id='msg1', + role=Role.user, + parts=[], ) async def send_message( - client: A2AClient, + client: Client, url: str, session_id: str | None = None, ) -> httpx.Request: """Mocks the response and sends a message using the client.""" - respx.post(url).mock( - return_value=httpx.Response( - 200, - json=build_success_response(), - ) - ) + respx.post(url).mock(side_effect=build_success_response) context = ClientCallContext( state={'sessionId': session_id} if session_id else {} ) - await client.send_message( - request=build_send_message_request(), + async for _ in client.send_message( + request=build_message(), context=context, - ) + ): + pass return respx.calls.last.request @@ -170,11 +169,26 @@ async def test_client_with_simple_interceptor(): """ url = 'http://agent.com/rpc' interceptor = HeaderInterceptor('X-Test-Header', 'Test-Value-123') + card = AgentCard( + url=url, + name='testbot', + description='test bot', + version='1.0', + default_input_modes=[], + default_output_modes=[], + skills=[], + capabilities=AgentCapabilities(), + preferred_transport=TransportProtocol.jsonrpc, + ) async with httpx.AsyncClient() as http_client: - client = A2AClient( - httpx_client=http_client, url=url, interceptors=[interceptor] + config = ClientConfig( + httpx_client=http_client, + supported_transports=[TransportProtocol.jsonrpc], ) + factory = ClientFactory(config) + client = factory.create(card, interceptors=[interceptor]) + request = await send_message(client, url) assert request.headers['x-test-header'] == 'Test-Value-123' @@ -293,14 +307,17 @@ async def test_auth_interceptor_variants(test_case, store): root=test_case.security_scheme ) }, + preferred_transport=TransportProtocol.jsonrpc, ) async with httpx.AsyncClient() as http_client: - client = A2AClient( + config = ClientConfig( httpx_client=http_client, - agent_card=agent_card, - interceptors=[auth_interceptor], + supported_transports=[TransportProtocol.jsonrpc], ) + factory = ClientFactory(config) + client = factory.create(agent_card, interceptors=[auth_interceptor]) + request = await send_message( client, test_case.url, test_case.session_id ) diff --git a/tests/client/test_grpc_client.py b/tests/client/test_grpc_client.py index 26967d73..7a9d6830 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/test_grpc_client.py @@ -2,7 +2,7 @@ import pytest -from a2a.client import A2AGrpcClient +from a2a.client.transports.grpc import GrpcTransport from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCapabilities, @@ -51,11 +51,11 @@ def sample_agent_card() -> AgentCard: @pytest.fixture -def grpc_client( +def grpc_transport( mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard -) -> A2AGrpcClient: - """Provides an A2AGrpcClient instance.""" - return A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=sample_agent_card) +) -> GrpcTransport: + """Provides a GrpcTransport instance.""" + return GrpcTransport(grpc_stub=mock_grpc_stub, agent_card=sample_agent_card) @pytest.fixture @@ -92,7 +92,7 @@ def sample_message() -> Message: @pytest.mark.asyncio async def test_send_message_task_response( - grpc_client: A2AGrpcClient, + grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_message_send_params: MessageSendParams, sample_task: Task, @@ -102,7 +102,7 @@ async def test_send_message_task_response( task=proto_utils.ToProto.task(sample_task) ) - response = await grpc_client.send_message(sample_message_send_params) + response = await grpc_transport.send_message(sample_message_send_params) mock_grpc_stub.SendMessage.assert_awaited_once() assert isinstance(response, Task) @@ -111,13 +111,13 @@ async def test_send_message_task_response( @pytest.mark.asyncio async def test_get_task( - grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task + grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task ): """Test retrieving a task.""" mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(sample_task) params = TaskQueryParams(id=sample_task.id) - response = await grpc_client.get_task(params) + response = await grpc_transport.get_task(params) mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest(name=f'tasks/{sample_task.id}') @@ -127,7 +127,7 @@ async def test_get_task( @pytest.mark.asyncio async def test_cancel_task( - grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task + grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task ): """Test cancelling a task.""" cancelled_task = sample_task.model_copy() @@ -137,7 +137,7 @@ async def test_cancel_task( ) params = TaskIdParams(id=sample_task.id) - response = await grpc_client.cancel_task(params) + response = await grpc_transport.cancel_task(params) mock_grpc_stub.CancelTask.assert_awaited_once_with( a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index c1ecc7ff..9698129f 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -1,55 +1,37 @@ import json - from collections.abc import AsyncGenerator from typing import Any from unittest.mock import ANY, AsyncMock, MagicMock, patch import httpx import pytest - -from httpx_sse import EventSource, SSEError, ServerSentEvent +from httpx_sse import EventSource, ServerSentEvent, SSEError from a2a.client import ( A2ACardResolver, - A2AClient, A2AClientHTTPError, A2AClientJSONError, A2AClientTimeoutError, create_text_message_object, ) +from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.types import ( - A2ARequest, AgentCapabilities, AgentCard, AgentSkill, - CancelTaskRequest, - CancelTaskResponse, - CancelTaskSuccessResponse, - GetTaskPushNotificationConfigRequest, - GetTaskPushNotificationConfigResponse, - GetTaskPushNotificationConfigSuccessResponse, - GetTaskRequest, - GetTaskResponse, InvalidParamsError, - JSONRPCErrorResponse, + Message, MessageSendParams, PushNotificationConfig, Role, - SendMessageRequest, - SendMessageResponse, SendMessageSuccessResponse, - SendStreamingMessageRequest, - SendStreamingMessageResponse, - SetTaskPushNotificationConfigRequest, - SetTaskPushNotificationConfigResponse, - SetTaskPushNotificationConfigSuccessResponse, + Task, TaskIdParams, TaskNotCancelableError, TaskPushNotificationConfig, TaskQueryParams, ) - AGENT_CARD = AgentCard( name='Hello World Agent', description='Just a hello world agent', @@ -116,15 +98,13 @@ def mock_httpx_client() -> AsyncMock: @pytest.fixture def mock_agent_card() -> MagicMock: mock = MagicMock(spec=AgentCard, url='http://agent.example.com/api') - # The attribute is accessed in the client's __init__ to determine if an - # extended card needs to be fetched. mock.supports_authenticated_extended_card = False return mock async def async_iterable_from_list( items: list[ServerSentEvent], -) -> AsyncGenerator[ServerSentEvent]: +) -> AsyncGenerator[ServerSentEvent, None]: """Helper to create an async iterable from a list.""" for item in items: yield item @@ -134,9 +114,7 @@ class TestA2ACardResolver: BASE_URL = 'http://example.com' AGENT_CARD_PATH = '/.well-known/agent.json' FULL_AGENT_CARD_URL = f'{BASE_URL}{AGENT_CARD_PATH}' - EXTENDED_AGENT_CARD_PATH = ( - '/agent/authenticatedExtendedCard' # Default path - ) + EXTENDED_AGENT_CARD_PATH = '/agent/authenticatedExtendedCard' @pytest.mark.asyncio async def test_init_parameters_stored_correctly( @@ -153,7 +131,6 @@ async def test_init_parameters_stored_correctly( assert resolver.agent_card_path == custom_path.lstrip('/') assert resolver.httpx_client == mock_httpx_client - # Test default agent_card_path resolver_default_path = A2ACardResolver( httpx_client=mock_httpx_client, base_url=base_url, @@ -164,13 +141,10 @@ async def test_init_parameters_stored_correctly( async def test_init_strips_slashes(self, mock_httpx_client: AsyncMock): resolver = A2ACardResolver( httpx_client=mock_httpx_client, - base_url='http://example.com/', # With trailing slash - agent_card_path='/.well-known/agent.json/', # With leading/trailing slash + base_url='http://example.com/', + agent_card_path='/.well-known/agent.json/', ) - assert ( - resolver.base_url == 'http://example.com' - ) # Trailing slash stripped - # constructor lstrips agent_card_path, but keeps trailing if provided + assert resolver.base_url == 'http://example.com' assert resolver.agent_card_path == '.well-known/agent.json/' @pytest.mark.asyncio @@ -195,7 +169,6 @@ async def test_get_agent_card_success_public_only( mock_response.raise_for_status.assert_called_once() assert isinstance(agent_card, AgentCard) assert agent_card == AGENT_CARD - # Ensure only one call was made (for the public card) assert mock_httpx_client.get.call_count == 1 @pytest.mark.asyncio @@ -207,8 +180,6 @@ async def test_get_agent_card_success_with_specified_path_for_extended_card( extended_card_response.json.return_value = ( AGENT_CARD_EXTENDED.model_dump(mode='json') ) - - # Mock the single call for the extended card mock_httpx_client.get.return_value = extended_card_response resolver = A2ACardResolver( @@ -217,7 +188,6 @@ async def test_get_agent_card_success_with_specified_path_for_extended_card( agent_card_path=self.AGENT_CARD_PATH, ) - # Fetch the extended card by providing its relative path and example auth auth_kwargs = {'headers': {'Authorization': 'Bearer test token'}} agent_card_result = await resolver.get_agent_card( relative_card_path=self.EXTENDED_AGENT_CARD_PATH, @@ -231,11 +201,8 @@ async def test_get_agent_card_success_with_specified_path_for_extended_card( expected_extended_url, **auth_kwargs ) extended_card_response.raise_for_status.assert_called_once() - assert isinstance(agent_card_result, AgentCard) - assert ( - agent_card_result == AGENT_CARD_EXTENDED - ) # Should return the extended card + assert agent_card_result == AGENT_CARD_EXTENDED @pytest.mark.asyncio async def test_get_agent_card_validation_error( @@ -243,7 +210,6 @@ async def test_get_agent_card_validation_error( ): mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 - # Data that will cause a Pydantic ValidationError mock_response.json.return_value = { 'invalid_field': 'value', 'name': 'Test Agent', @@ -253,31 +219,23 @@ async def test_get_agent_card_validation_error( resolver = A2ACardResolver( httpx_client=mock_httpx_client, base_url=self.BASE_URL ) - # The call that is expected to raise an error should be within pytest.raises with pytest.raises(A2AClientJSONError) as exc_info: - await resolver.get_agent_card() # Fetches from default path + await resolver.get_agent_card() assert ( f'Failed to validate agent card structure from {self.FULL_AGENT_CARD_URL}' in str(exc_info.value) ) - assert 'invalid_field' in str( - exc_info.value - ) # Check if Pydantic error details are present - assert ( - mock_httpx_client.get.call_count == 1 - ) # Should only be called once + assert 'invalid_field' in str(exc_info.value) + assert mock_httpx_client.get.call_count == 1 @pytest.mark.asyncio async def test_get_agent_card_http_status_error( self, mock_httpx_client: AsyncMock ): - mock_response = MagicMock( - spec=httpx.Response - ) # Use MagicMock for response attribute + mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 404 mock_response.text = 'Not Found' - http_status_error = httpx.HTTPStatusError( 'Not Found', request=MagicMock(), response=mock_response ) @@ -306,7 +264,6 @@ async def test_get_agent_card_json_decode_error( ): mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 - # Define json_error before using it json_error = json.JSONDecodeError('Expecting value', 'doc', 0) mock_response.json.side_effect = json_error mock_httpx_client.get.return_value = mock_response @@ -320,7 +277,6 @@ async def test_get_agent_card_json_decode_error( with pytest.raises(A2AClientJSONError) as exc_info: await resolver.get_agent_card() - # Assertions using exc_info must be after the with block assert ( f'Failed to parse JSON for agent card from {self.FULL_AGENT_CARD_URL}' in str(exc_info.value) @@ -353,430 +309,153 @@ async def test_get_agent_card_request_error( mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL) -class TestA2AClient: +class TestJsonRpcTransport: AGENT_URL = 'http://agent.example.com/api' def test_init_with_agent_card( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) assert client.url == mock_agent_card.url assert client.httpx_client == mock_httpx_client def test_init_with_url(self, mock_httpx_client: AsyncMock): - client = A2AClient(httpx_client=mock_httpx_client, url=self.AGENT_URL) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, url=self.AGENT_URL + ) assert client.url == self.AGENT_URL assert client.httpx_client == mock_httpx_client def test_init_with_agent_card_and_url_prioritizes_agent_card( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, url='http://otherurl.com', ) - assert ( - client.url == mock_agent_card.url - ) # Agent card URL should be used + assert client.url == mock_agent_card.url def test_init_raises_value_error_if_no_card_or_url( self, mock_httpx_client: AsyncMock ): with pytest.raises(ValueError) as exc_info: - A2AClient(httpx_client=mock_httpx_client) + JsonRpcTransport(httpx_client=mock_httpx_client) assert 'Must provide either agent_card or url' in str(exc_info.value) @pytest.mark.asyncio - async def test_get_client_from_agent_card_url_success( + async def test_send_message_success( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - base_url = 'http://example.com' - agent_card_path = '/.well-known/custom-agent.json' - resolver_kwargs = {'timeout': 30} - - mock_resolver_instance = AsyncMock(spec=A2ACardResolver) - mock_resolver_instance.get_agent_card.return_value = mock_agent_card - - with patch( - 'a2a.client.jsonrpc_client.A2ACardResolver', - return_value=mock_resolver_instance, - ) as mock_resolver_class: - client = await A2AClient.get_client_from_agent_card_url( - httpx_client=mock_httpx_client, - base_url=base_url, - agent_card_path=agent_card_path, - http_kwargs=resolver_kwargs, - ) - - mock_resolver_class.assert_called_once_with( - mock_httpx_client, - base_url=base_url, - agent_card_path=agent_card_path, - ) - mock_resolver_instance.get_agent_card.assert_called_once_with( - http_kwargs=resolver_kwargs, - # relative_card_path=None is implied by not passing it - ) - assert isinstance(client, A2AClient) - assert client.url == mock_agent_card.url - assert client.httpx_client == mock_httpx_client - - @pytest.mark.asyncio - async def test_get_client_from_agent_card_url_resolver_error( - self, mock_httpx_client: AsyncMock - ): - error_to_raise = A2AClientHTTPError(404, 'Agent card not found') - with patch( - 'a2a.client.jsonrpc_client.A2ACardResolver.get_agent_card', - new_callable=AsyncMock, - side_effect=error_to_raise, - ): - with pytest.raises(A2AClientHTTPError) as exc_info: - await A2AClient.get_client_from_agent_card_url( - httpx_client=mock_httpx_client, - base_url='http://example.com', - ) - assert exc_info.value == error_to_raise - - @pytest.mark.asyncio - async def test_send_message_success_use_request( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - params = MessageSendParams( message=create_text_message_object(content='Hello') ) - - request = SendMessageRequest(id=123, params=params) - success_response = create_text_message_object( role=Role.agent, content='Hi there!' - ).model_dump(exclude_none=True) + ) + rpc_response = SendMessageSuccessResponse( + id='123', jsonrpc='2.0', result=success_response + ) + response = httpx.Response( + 200, json=rpc_response.model_dump(mode='json') + ) + response.request = httpx.Request('POST', 'http://agent.example.com/api') + mock_httpx_client.post.return_value = response - rpc_response: dict[str, Any] = { - 'id': 123, - 'jsonrpc': '2.0', - 'result': success_response, - } + response = await client.send_message(request=params) - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response - response = await client.send_message( - request=request, http_kwargs={'timeout': 10} - ) - - assert mock_send_req.call_count == 1 - called_args, called_kwargs = mock_send_req.call_args - assert not called_kwargs # no kwargs to _send_request - assert len(called_args) == 2 - json_rpc_request: dict[str, Any] = called_args[0] - assert isinstance(json_rpc_request['id'], int) - http_kwargs: dict[str, Any] = called_args[1] - assert http_kwargs['timeout'] == 10 - - a2a_request_arg = A2ARequest.model_validate(json_rpc_request) - assert isinstance(a2a_request_arg.root, SendMessageRequest) - assert isinstance(a2a_request_arg.root.params, MessageSendParams) - - assert a2a_request_arg.root.params.model_dump( - exclude_none=True - ) == params.model_dump(exclude_none=True) - - assert isinstance(response, SendMessageResponse) - assert isinstance(response.root, SendMessageSuccessResponse) - assert ( - response.root.result.model_dump(exclude_none=True) - == success_response - ) + assert isinstance(response, Message) + assert response.model_dump() == success_response.model_dump() @pytest.mark.asyncio async def test_send_message_error_response( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - params = MessageSendParams( message=create_text_message_object(content='Hello') ) - - request = SendMessageRequest(id=123, params=params) - error_response = InvalidParamsError() - - rpc_response: dict[str, Any] = { - 'id': 123, + rpc_response = { + 'id': '123', 'jsonrpc': '2.0', 'error': error_response.model_dump(exclude_none=True), } + mock_httpx_client.post.return_value.json.return_value = rpc_response - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response - response = await client.send_message(request=request) - - assert isinstance(response, SendMessageResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - exclude_none=True - ) == InvalidParamsError().model_dump(exclude_none=True) + with pytest.raises(Exception): + await client.send_message(request=params) @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_success_request( + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_success( self, mock_aconnect_sse: AsyncMock, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) params = MessageSendParams( message=create_text_message_object(content='Hello stream') ) - - request = SendStreamingMessageRequest(id=123, params=params) - - mock_stream_response_1_dict: dict[str, Any] = { - 'id': 'stream_id_123', - 'jsonrpc': '2.0', - 'result': create_text_message_object( + mock_stream_response_1 = SendMessageSuccessResponse( + id='stream_id_123', + jsonrpc='2.0', + result=create_text_message_object( content='First part ', role=Role.agent - ).model_dump(mode='json', exclude_none=True), - } - mock_stream_response_2_dict: dict[str, Any] = { - 'id': 'stream_id_123', - 'jsonrpc': '2.0', - 'result': create_text_message_object( + ), + ) + mock_stream_response_2 = SendMessageSuccessResponse( + id='stream_id_123', + jsonrpc='2.0', + result=create_text_message_object( content='second part ', role=Role.agent - ).model_dump(mode='json', exclude_none=True), - } - + ), + ) sse_event_1 = ServerSentEvent( - data=json.dumps(mock_stream_response_1_dict) + data=mock_stream_response_1.model_dump_json() ) sse_event_2 = ServerSentEvent( - data=json.dumps(mock_stream_response_2_dict) + data=mock_stream_response_2.model_dump_json() ) - - mock_event_source = AsyncMock(spec=EventSource) - with patch.object(mock_event_source, 'aiter_sse') as mock_aiter_sse: - mock_aiter_sse.return_value = async_iterable_from_list( - [sse_event_1, sse_event_2] - ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - results: list[Any] = [] - async for response in client.send_message_streaming( - request=request - ): - results.append(response) - - assert len(results) == 2 - assert isinstance(results[0], SendStreamingMessageResponse) - # Assuming SendStreamingMessageResponse is a RootModel like SendMessageResponse - assert results[0].root.id == 'stream_id_123' - assert ( - results[0].root.result.model_dump( # type: ignore - mode='json', exclude_none=True - ) - == mock_stream_response_1_dict['result'] - ) - - assert isinstance(results[1], SendStreamingMessageResponse) - assert results[1].root.id == 'stream_id_123' - assert ( - results[1].root.result.model_dump( # type: ignore - mode='json', exclude_none=True - ) - == mock_stream_response_2_dict['result'] - ) - - mock_aconnect_sse.assert_called_once() - call_args, call_kwargs = mock_aconnect_sse.call_args - assert call_args[0] == mock_httpx_client - assert call_args[1] == 'POST' - assert call_args[2] == mock_agent_card.url - - sent_json_payload = call_kwargs['json'] - assert sent_json_payload['method'] == 'message/stream' - assert sent_json_payload['params'] == params.model_dump( - mode='json', exclude_none=True - ) - assert ( - call_kwargs['timeout'] is None - ) # Default timeout for streaming - - @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_http_kwargs_passed( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params = MessageSendParams( - message=create_text_message_object(content='Stream with kwargs') - ) - request = SendStreamingMessageRequest(id='kwarg_req', params=params) - custom_kwargs = { - 'headers': {'X-Custom-Header': 'TestValue'}, - 'timeout': 60, - } - - # Setup mock_aconnect_sse to behave minimally mock_event_source = AsyncMock(spec=EventSource) mock_event_source.aiter_sse.return_value = async_iterable_from_list( - [] - ) # No events needed for this test - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - async for _ in client.send_message_streaming( - request=request, http_kwargs=custom_kwargs - ): - pass # We just want to check the call to aconnect_sse - - mock_aconnect_sse.assert_called_once() - _, called_kwargs = mock_aconnect_sse.call_args - assert called_kwargs['headers'] == custom_kwargs['headers'] - assert ( - called_kwargs['timeout'] == custom_kwargs['timeout'] - ) # Ensure custom timeout is used - - @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_sse_error_handling( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - request = SendStreamingMessageRequest( - id='sse_err_req', - params=MessageSendParams( - message=create_text_message_object(content='SSE error test') - ), - ) - - # Configure the mock aconnect_sse to raise SSEError when aiter_sse is called - mock_event_source = AsyncMock(spec=EventSource) - mock_event_source.aiter_sse.side_effect = SSEError( - 'Simulated SSE protocol error' + [sse_event_1, sse_event_2] ) mock_aconnect_sse.return_value.__aenter__.return_value = ( mock_event_source ) - with pytest.raises(A2AClientHTTPError) as exc_info: - async for _ in client.send_message_streaming(request=request): - pass - - assert exc_info.value.status_code == 400 # As per client implementation - assert 'Invalid SSE response or protocol error' in str(exc_info.value) - assert 'Simulated SSE protocol error' in str(exc_info.value) - - @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_json_decode_error_handling( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - request = SendStreamingMessageRequest( - id='json_err_req', - params=MessageSendParams( - message=create_text_message_object(content='JSON error test') - ), - ) - - # Malformed JSON event - malformed_sse_event = ServerSentEvent(data='not valid json') + results = [ + item async for item in client.send_message_streaming(request=params) + ] - mock_event_source = AsyncMock(spec=EventSource) - # json.loads will be called on "not valid json" and raise JSONDecodeError - mock_event_source.aiter_sse.return_value = async_iterable_from_list( - [malformed_sse_event] + assert len(results) == 2 + assert isinstance(results[0], Message) + assert ( + results[0].model_dump() + == mock_stream_response_1.result.model_dump() ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source + assert isinstance(results[1], Message) + assert ( + results[1].model_dump() + == mock_stream_response_2.result.model_dump() ) - with pytest.raises(A2AClientJSONError) as exc_info: - async for _ in client.send_message_streaming(request=request): - pass - - assert 'Expecting value: line 1 column 1 (char 0)' in str( - exc_info.value - ) # Example of JSONDecodeError message - - @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_httpx_request_error_handling( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - request = SendStreamingMessageRequest( - id='httpx_err_req', - params=MessageSendParams( - message=create_text_message_object(content='httpx error test') - ), - ) - - # Configure aconnect_sse itself to raise httpx.RequestError (e.g., during connection) - # This needs to be raised when aconnect_sse is entered or iterated. - # One way is to make the context manager's __aenter__ raise it, or aiter_sse. - # For simplicity, let's make aiter_sse raise it, as if the error occurs after connection. - mock_event_source = AsyncMock(spec=EventSource) - mock_event_source.aiter_sse.side_effect = httpx.RequestError( - 'Simulated network error', request=MagicMock() - ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - with pytest.raises(A2AClientHTTPError) as exc_info: - async for _ in client.send_message_streaming(request=request): - pass - - assert exc_info.value.status_code == 503 # As per client implementation - assert 'Network communication error' in str(exc_info.value) - assert 'Simulated network error' in str(exc_info.value) - @pytest.mark.asyncio async def test_send_request_http_status_error( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) mock_response = MagicMock(spec=httpx.Response) @@ -797,7 +476,7 @@ async def test_send_request_http_status_error( async def test_send_request_json_decode_error( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) mock_response = AsyncMock(spec=httpx.Response) @@ -815,7 +494,7 @@ async def test_send_request_json_decode_error( async def test_send_request_httpx_request_error( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) request_error = httpx.RequestError('Network issue', request=MagicMock()) @@ -828,451 +507,6 @@ async def test_send_request_httpx_request_error( assert 'Network communication error' in str(exc_info.value) assert 'Network issue' in str(exc_info.value) - @pytest.mark.asyncio - async def test_set_task_callback_success( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - task_id_val = 'task_set_cb_001' - # Correctly create the PushNotificationConfig (inner model) - push_config_payload = PushNotificationConfig( - url='https://callback.example.com/taskupdate' - ) - # Correctly create the TaskPushNotificationConfig (outer model) - params_model = TaskPushNotificationConfig( - task_id=task_id_val, push_notification_config=push_config_payload - ) - - # request.id will be generated by the client method if not provided - request = SetTaskPushNotificationConfigRequest( - id='', params=params_model - ) # Test ID auto-generation - - # The result for a successful set operation is the same config - rpc_response_payload: dict[str, Any] = { - 'id': ANY, # Will be checked against generated ID - 'jsonrpc': '2.0', - 'result': params_model.model_dump(mode='json', exclude_none=True), - } - - with ( - patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req, - patch( - 'a2a.client.jsonrpc_client.uuid4', - return_value=MagicMock(hex='testuuid'), - ) as mock_uuid, - ): - # Capture the generated ID for assertion - generated_id = str(mock_uuid.return_value) - rpc_response_payload['id'] = ( - generated_id # Ensure mock response uses the generated ID - ) - mock_send_req.return_value = rpc_response_payload - - response = await client.set_task_callback(request=request) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args - sent_json_payload = called_args[0] - - assert sent_json_payload['id'] == generated_id - assert ( - sent_json_payload['method'] - == 'tasks/pushNotificationConfig/set' - ) - assert sent_json_payload['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, SetTaskPushNotificationConfigResponse) - assert isinstance( - response.root, SetTaskPushNotificationConfigSuccessResponse - ) - assert response.root.id == generated_id - assert response.root.result.model_dump( - mode='json', exclude_none=True - ) == params_model.model_dump(mode='json', exclude_none=True) - - @pytest.mark.asyncio - async def test_set_task_callback_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - req_id = 'set_cb_err_req' - push_config_payload = PushNotificationConfig(url='https://errors.com') - params_model = TaskPushNotificationConfig( - task_id='task_err_cb', push_notification_config=push_config_payload - ) - request = SetTaskPushNotificationConfigRequest( - id=req_id, params=params_model - ) - error_details = InvalidParamsError(message='Invalid callback URL') - - rpc_response_payload: dict[str, Any] = { - 'id': req_id, - 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.set_task_callback(request=request) - - assert isinstance(response, SetTaskPushNotificationConfigResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(mode='json', exclude_none=True) - assert response.root.id == req_id - - @pytest.mark.asyncio - async def test_set_task_callback_http_kwargs_passed( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - push_config_payload = PushNotificationConfig(url='https://kwargs.com') - params_model = TaskPushNotificationConfig( - task_id='task_cb_kwargs', - push_notification_config=push_config_payload, - ) - request = SetTaskPushNotificationConfigRequest( - id='cb_kwargs_req', params=params_model - ) - custom_kwargs = {'headers': {'X-Callback-Token': 'secret'}} - - # Minimal successful response - rpc_response_payload: dict[str, Any] = { - 'id': 'cb_kwargs_req', - 'jsonrpc': '2.0', - 'result': params_model.model_dump(mode='json'), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - await client.set_task_callback( - request=request, http_kwargs=custom_kwargs - ) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args # Correctly unpack args - assert ( - called_args[1] == custom_kwargs - ) # http_kwargs is the second positional arg - - @pytest.mark.asyncio - async def test_get_task_callback_success( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - task_id_val = 'task_get_cb_001' - params_model = TaskIdParams( - id=task_id_val - ) # Params for get is just TaskIdParams - - request = GetTaskPushNotificationConfigRequest( - id='', params=params_model - ) # ID is empty string for auto-generation test - - # Expected result for a successful get operation - push_config_payload = PushNotificationConfig( - url='https://callback.example.com/taskupdate' - ) - expected_callback_config = TaskPushNotificationConfig( - task_id=task_id_val, push_notification_config=push_config_payload - ) - rpc_response_payload: dict[str, Any] = { - 'id': ANY, - 'jsonrpc': '2.0', - 'result': expected_callback_config.model_dump( - mode='json', exclude_none=True - ), - } - - with ( - patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req, - patch( - 'a2a.client.jsonrpc_client.uuid4', - return_value=MagicMock(hex='testgetuuid'), - ) as mock_uuid, - ): - generated_id = str(mock_uuid.return_value) - rpc_response_payload['id'] = generated_id - mock_send_req.return_value = rpc_response_payload - - response = await client.get_task_callback(request=request) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args - sent_json_payload = called_args[0] - - assert sent_json_payload['id'] == generated_id - assert ( - sent_json_payload['method'] - == 'tasks/pushNotificationConfig/get' - ) - assert sent_json_payload['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, GetTaskPushNotificationConfigResponse) - assert isinstance( - response.root, GetTaskPushNotificationConfigSuccessResponse - ) - assert response.root.id == generated_id - assert response.root.result.model_dump( - mode='json', exclude_none=True - ) == expected_callback_config.model_dump( - mode='json', exclude_none=True - ) - - @pytest.mark.asyncio - async def test_get_task_callback_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - req_id = 'get_cb_err_req' - params_model = TaskIdParams(id='task_get_err_cb') - request = GetTaskPushNotificationConfigRequest( - id=req_id, params=params_model - ) - error_details = TaskNotCancelableError( - message='Cannot get callback for uncancelable task' - ) # Example error - - rpc_response_payload: dict[str, Any] = { - 'id': req_id, - 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.get_task_callback(request=request) - - assert isinstance(response, GetTaskPushNotificationConfigResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(mode='json', exclude_none=True) - assert response.root.id == req_id - - @pytest.mark.asyncio - async def test_get_task_callback_http_kwargs_passed( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params_model = TaskIdParams(id='task_get_cb_kwargs') - request = GetTaskPushNotificationConfigRequest( - id='get_cb_kwargs_req', params=params_model - ) - custom_kwargs = {'headers': {'X-Tenant-ID': 'tenant-x'}} - - # Correctly create the nested PushNotificationConfig - push_config_payload_for_expected = PushNotificationConfig( - url='https://getkwargs.com' - ) - expected_callback_config = TaskPushNotificationConfig( - task_id='task_get_cb_kwargs', - push_notification_config=push_config_payload_for_expected, - ) - rpc_response_payload: dict[str, Any] = { - 'id': 'get_cb_kwargs_req', - 'jsonrpc': '2.0', - 'result': expected_callback_config.model_dump(mode='json'), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - await client.get_task_callback( - request=request, http_kwargs=custom_kwargs - ) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args # Correctly unpack args - assert ( - called_args[1] == custom_kwargs - ) # http_kwargs is the second positional arg - - @pytest.mark.asyncio - async def test_get_task_success_use_request( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - task_id_val = 'task_for_req_obj' - params_model = TaskQueryParams(id=task_id_val) - request_obj_id = 789 - request = GetTaskRequest(id=request_obj_id, params=params_model) - - rpc_response_payload: dict[str, Any] = { - 'id': request_obj_id, - 'jsonrpc': '2.0', - 'result': MINIMAL_TASK, - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.get_task( - request=request, http_kwargs={'timeout': 20} - ) - - assert mock_send_req.call_count == 1 - called_args, called_kwargs = mock_send_req.call_args - assert len(called_args) == 2 - json_rpc_request_sent: dict[str, Any] = called_args[0] - assert not called_kwargs # no extra kwargs to _send_request - http_kwargs: dict[str, Any] = called_args[1] - assert http_kwargs['timeout'] == 20 - - assert json_rpc_request_sent['method'] == 'tasks/get' - assert json_rpc_request_sent['id'] == request_obj_id - assert json_rpc_request_sent['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, GetTaskResponse) - assert hasattr(response.root, 'result') - assert ( - response.root.result.model_dump(mode='json', exclude_none=True) # type: ignore - == MINIMAL_TASK - ) - assert response.root.id == request_obj_id - - @pytest.mark.asyncio - async def test_get_task_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params_model = TaskQueryParams(id='task_error_case') - request = GetTaskRequest(id='err_req_id', params=params_model) - error_details = InvalidParamsError() - - rpc_response_payload: dict[str, Any] = { - 'id': 'err_req_id', - 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.get_task(request=request) - - assert isinstance(response, GetTaskResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(exclude_none=True) - assert response.root.id == 'err_req_id' - - @pytest.mark.asyncio - async def test_cancel_task_success_use_request( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - task_id_val = MINIMAL_CANCELLED_TASK['id'] - params_model = TaskIdParams(id=task_id_val) - request_obj_id = 'cancel_req_obj_id_001' - request = CancelTaskRequest(id=request_obj_id, params=params_model) - - rpc_response_payload: dict[str, Any] = { - 'id': request_obj_id, - 'jsonrpc': '2.0', - 'result': MINIMAL_CANCELLED_TASK, - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.cancel_task( - request=request, http_kwargs={'timeout': 15} - ) - - assert mock_send_req.call_count == 1 - called_args, called_kwargs = mock_send_req.call_args - assert not called_kwargs # no extra kwargs to _send_request - assert len(called_args) == 2 - json_rpc_request_sent: dict[str, Any] = called_args[0] - http_kwargs: dict[str, Any] = called_args[1] - assert http_kwargs['timeout'] == 15 - - assert json_rpc_request_sent['method'] == 'tasks/cancel' - assert json_rpc_request_sent['id'] == request_obj_id - assert json_rpc_request_sent['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, CancelTaskResponse) - assert isinstance(response.root, CancelTaskSuccessResponse) - assert ( - response.root.result.model_dump(mode='json', exclude_none=True) # type: ignore - == MINIMAL_CANCELLED_TASK - ) - assert response.root.id == request_obj_id - - @pytest.mark.asyncio - async def test_cancel_task_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params_model = TaskIdParams(id='task_cancel_error_case') - request = CancelTaskRequest(id='err_cancel_req', params=params_model) - error_details = TaskNotCancelableError() - - rpc_response_payload: dict[str, Any] = { - 'id': 'err_cancel_req', - 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.cancel_task(request=request) - - assert isinstance(response, CancelTaskResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(exclude_none=True) - assert response.root.id == 'err_cancel_req' - @pytest.mark.asyncio async def test_send_message_client_timeout( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock @@ -1280,17 +514,14 @@ async def test_send_message_client_timeout( mock_httpx_client.post.side_effect = httpx.ReadTimeout( 'Request timed out' ) - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - params = MessageSendParams( message=create_text_message_object(content='Hello') ) - request = SendMessageRequest(id=123, params=params) - with pytest.raises(A2AClientTimeoutError) as exc_info: - await client.send_message(request=request) + await client.send_message(request=params) assert 'Request timed out' in str(exc_info.value) From 8e70dc4888df858ee5090396ee88644b5264d82a Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 29 Jul 2025 22:22:39 +0000 Subject: [PATCH 03/12] Gemini authored: add an integration test that exercises transport client + server --- .../test_client_server_integration.py | 233 ++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 tests/integration/test_client_server_integration.py diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py new file mode 100644 index 00000000..53ffb384 --- /dev/null +++ b/tests/integration/test_client_server_integration.py @@ -0,0 +1,233 @@ +import asyncio +from typing import Any, AsyncGenerator, NamedTuple +from unittest.mock import AsyncMock + +import grpc +import httpx +import pytest +import pytest_asyncio +from grpc.aio import Channel +from starlette.testclient import TestClient + +from a2a.client.transports import JsonRpcTransport, RestTransport +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.grpc import GrpcTransport +from a2a.grpc import a2a_pb2_grpc +from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from a2a.server.request_handlers import GrpcHandler, RequestHandler +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentInterface, + Message, + MessageSendParams, + Part, + Role, + Task, + TaskState, + TaskStatus, + TextPart, + TransportProtocol, +) + +# --- Test Constants --- + +TASK_FROM_STREAM = Task( + id="task-123-stream", + context_id="ctx-456-stream", + status=TaskStatus(state=TaskState.completed), + kind="task", +) + + +# --- Test Fixtures --- + + +@pytest.fixture +def mock_request_handler() -> AsyncMock: + """Provides a mock RequestHandler for the server-side handlers.""" + handler = AsyncMock(spec=RequestHandler) + + async def stream_side_effect(*args, **kwargs): + yield TASK_FROM_STREAM + + handler.on_message_send_stream.side_effect = stream_side_effect + return handler + + +@pytest.fixture +def agent_card() -> AgentCard: + """Provides a sample AgentCard for tests.""" + return AgentCard( + name='Test Agent', + description='An agent for integration testing.', + url='http://testserver', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + skills=[], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + preferred_transport=TransportProtocol.jsonrpc, + additional_interfaces=[ + AgentInterface( + transport=TransportProtocol.http_json, url='http://testserver' + ), + AgentInterface( + transport=TransportProtocol.grpc, url='localhost:50051' + ), + ], + ) + + +class TransportSetup(NamedTuple): + """Holds the transport and handler for a given test.""" + + transport: ClientTransport + handler: AsyncMock + + +# --- HTTP/JSON-RPC/REST Setup --- + + +@pytest.fixture +def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): + """A base fixture to patch the sse-starlette event loop issue.""" + from sse_starlette import sse + + sse.AppStatus.should_exit_event = asyncio.Event() + yield mock_request_handler, agent_card + + +@pytest.fixture +def jsonrpc_setup(http_base_setup) -> TransportSetup: + """Sets up the JsonRpcTransport and in-memory server.""" + mock_request_handler, agent_card = http_base_setup + app_builder = A2AFastAPIApplication(agent_card, mock_request_handler) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + transport = JsonRpcTransport( + httpx_client=httpx_client, agent_card=agent_card + ) + return TransportSetup(transport=transport, handler=mock_request_handler) + + +@pytest.fixture +def rest_setup(http_base_setup) -> TransportSetup: + """Sets up the RestTransport and in-memory server.""" + mock_request_handler, agent_card = http_base_setup + app_builder = A2ARESTFastAPIApplication(agent_card, mock_request_handler) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card) + return TransportSetup(transport=transport, handler=mock_request_handler) + + +# --- gRPC Setup --- + + +@pytest_asyncio.fixture +async def grpc_server_and_handler( + mock_request_handler: AsyncMock, agent_card: AgentCard +) -> AsyncGenerator[tuple[str, AsyncMock], None]: + """Creates and manages an in-process gRPC test server.""" + server = grpc.aio.server() + port = server.add_insecure_port('[::]:0') + server_address = f'localhost:{port}' + servicer = GrpcHandler(agent_card, mock_request_handler) + a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) + await server.start() + yield server_address, mock_request_handler + await server.stop(0) + + +# --- The Integration Test --- + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_sends_message( + transport_setup_fixture: str, request +) -> None: + """ + Integration test for HTTP-based transports (JSON-RPC, REST). + """ + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + message_to_send = Message( + role=Role.user, + message_id='msg-integration-test', + parts=[Part(root=TextPart(text='Hello, integration test!'))], + ) + params = MessageSendParams(message=message_to_send) + + stream = transport.send_message_streaming(request=params) + first_event = await anext(stream) + + assert first_event.id == TASK_FROM_STREAM.id + assert first_event.context_id == TASK_FROM_STREAM.context_id + + handler.on_message_send_stream.assert_called_once() + call_args, _ = handler.on_message_send_stream.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_sends_message( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + """ + Integration test specifically for the gRPC transport. + """ + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + message_to_send = Message( + role=Role.user, + message_id='msg-grpc-integration-test', + parts=[Part(root=TextPart(text='Hello, gRPC integration test!'))], + ) + params = MessageSendParams(message=message_to_send) + + stream = transport.send_message_streaming(request=params) + first_event = await anext(stream) + + assert first_event.id == TASK_FROM_STREAM.id + assert first_event.context_id == TASK_FROM_STREAM.context_id + + handler.on_message_send_stream.assert_called_once() + call_args, _ = handler.on_message_send_stream.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + await transport.close() From d632debd671794fd1aeb1cb6477c377bef428395 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 30 Jul 2025 03:09:44 +0000 Subject: [PATCH 04/12] Add integration tests, fix uncovered bugs. --- src/a2a/client/transports/grpc.py | 6 +- src/a2a/client/transports/jsonrpc.py | 13 +- src/a2a/client/transports/rest.py | 4 +- src/a2a/server/apps/rest/rest_app.py | 9 +- .../server/request_handlers/rest_handler.py | 15 +- src/a2a/utils/helpers.py | 35 +- src/a2a/utils/proto_utils.py | 8 +- .../test_client_server_integration.py | 503 +++++++++++++++++- tests/utils/test_proto_utils.py | 9 +- 9 files changed, 558 insertions(+), 44 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index e75146bc..c340ed63 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -144,8 +144,8 @@ async def set_task_callback( """Sets or updates the push notification configuration for a specific task.""" config = await self.stub.CreateTaskPushNotificationConfig( a2a_pb2.CreateTaskPushNotificationConfigRequest( - parent='', - config_id='', + parent=f'tasks/{request.task_id}', + config_id=request.push_notification_config.id, config=proto_utils.ToProto.task_push_notification_config( request ), @@ -162,7 +162,7 @@ async def get_task_callback( """Retrieves the push notification configuration for a specific task.""" config = await self.stub.GetTaskPushNotificationConfig( a2a_pb2.GetTaskPushNotificationConfigRequest( - name=f'tasks/{request.id}/pushNotification/{request.push_notification_config_id}', + name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', ) ) return proto_utils.FromProto.task_push_notification_config(config) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index be8bca5e..4bc2dfa6 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -28,6 +28,7 @@ GetTaskRequest, GetTaskResponse, JSONRPCErrorResponse, + JSONRPCRequest, Message, MessageSendParams, SendMessageRequest, @@ -348,13 +349,21 @@ async def get_card( if not self._needs_extended_card: return card - payload, modified_kwargs = await self._apply_interceptors( + _, modified_kwargs = await self._apply_interceptors( 'agent/getAuthenticatedExtendedCard', {}, self._get_http_args(context), context, ) - response_data = await self._send_request(payload, modified_kwargs) + + response_data = await self._send_request( + JSONRPCRequest( + method='agent/getAuthenticatedExtendedCard', + params={}, + id=str(uuid4()), + ).model_dump(), + modified_kwargs, + ) card = AgentCard.model_validate(response_data) self.agent_card = card self._needs_extended_card = False diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 7c0ef7ab..9b22d165 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -268,7 +268,7 @@ async def set_task_callback( payload, self._get_http_args(context), context ) response_data = await self._send_post_request( - f'/v1/tasks/{request.task_id}/pushNotificationConfigs/', + f'/v1/tasks/{request.task_id}/pushNotificationConfigs', payload, modified_kwargs, ) @@ -371,7 +371,7 @@ async def get_card( context, ) response_data = await self._send_get_request( - '/v1/card/get', {}, modified_kwargs + '/v1/card', {}, modified_kwargs ) card = AgentCard.model_validate(response_data) self.agent_card = card diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index 57b2c295..ad723f04 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -17,9 +17,7 @@ ) from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.rest_handler import ( - RESTHandler, -) +from a2a.server.request_handlers.rest_handler import RESTHandler from a2a.types import ( A2AError, AgentCard, @@ -198,10 +196,13 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: ('/v1/message:send', 'POST'): functools.partial( self._handle_request, self.handler.on_message_send ), - ('/v1/message:stream', 'POST'): functools.partial( + ('/v1/message:stream', 'GET'): functools.partial( self._handle_streaming_request, self.handler.on_message_send_stream, ), + ('/v1/tasks/{id}:cancel', 'POST'): functools.partial( + self._handle_request, self.handler.on_cancel_task + ), ('/v1/tasks/{id}:subscribe', 'POST'): functools.partial( self._handle_streaming_request, self.handler.on_resubscribe_to_task, diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 953d5e77..31337534 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -2,7 +2,7 @@ from collections.abc import AsyncIterable -from google.protobuf.json_format import MessageToJson, Parse +from google.protobuf.json_format import MessageToDict, MessageToJson, Parse from starlette.requests import Request from a2a.grpc import a2a_pb2 @@ -85,7 +85,7 @@ async def on_message_send( task_or_message = await self.request_handler.on_message_send( a2a_request, context ) - return MessageToJson( + return MessageToDict( proto_utils.ToProto.task_or_message(task_or_message) ) except ServerError as e: @@ -161,7 +161,7 @@ async def on_cancel_task( TaskIdParams(id=task_id), context ) if task: - return MessageToJson(proto_utils.ToProto.task(task)) + return MessageToDict(proto_utils.ToProto.task(task)) raise ServerError(error=TaskNotFoundError()) except ServerError as e: raise A2AErrorWrapperError( @@ -236,7 +236,7 @@ async def get_push_notification( params, context ) ) - return MessageToJson( + return MessageToDict( proto_utils.ToProto.task_push_notification_config(config) ) except ServerError as e: @@ -270,7 +270,7 @@ async def set_push_notification( found. """ try: - _ = request.path_params['id'] + task_id = request.path_params['id'] body = await request.body() params = a2a_pb2.CreateTaskPushNotificationConfigRequest() Parse(body, params) @@ -279,12 +279,13 @@ async def set_push_notification( params, ) ) + a2a_request.task_id = task_id config = ( await self.request_handler.on_set_task_push_notification_config( a2a_request, context ) ) - return MessageToJson( + return MessageToDict( proto_utils.ToProto.task_push_notification_config(config) ) except ServerError as e: @@ -318,7 +319,7 @@ async def on_get_task( params = TaskQueryParams(id=task_id, history_length=history_length) task = await self.request_handler.on_get_task(params, context) if task: - return MessageToJson(proto_utils.ToProto.task(task)) + return MessageToDict(proto_utils.ToProto.task(task)) raise ServerError(error=TaskNotFoundError()) except ServerError as e: raise A2AErrorWrapperError( diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 091268ba..5f9ee94b 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -1,6 +1,7 @@ """General utility functions for the A2A Python SDK.""" import functools +import inspect import logging from collections.abc import Callable @@ -135,16 +136,32 @@ def validate( """ def decorator(function: Callable) -> Callable: - def wrapper(self: Any, *args, **kwargs) -> Any: - if not expression(self): - final_message = error_message or str(expression) - logger.error(f'Unsupported Operation: {final_message}') - raise ServerError( - UnsupportedOperationError(message=final_message) - ) - return function(self, *args, **kwargs) + if inspect.iscoroutinefunction(function): + + @functools.wraps(function) + async def async_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logger.error(f'Unsupported Operation: {final_message}') + raise ServerError( + UnsupportedOperationError(message=final_message) + ) + return await function(self, *args, **kwargs) + + return async_wrapper + else: - return wrapper + @functools.wraps(function) + def sync_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logger.error(f'Unsupported Operation: {final_message}') + raise ServerError( + UnsupportedOperationError(message=final_message) + ) + return function(self, *args, **kwargs) + + return sync_wrapper return decorator diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 4d87280f..05f03ed1 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -2,6 +2,7 @@ """Utils for converting between proto and Python types.""" import json +import logging import re from typing import Any @@ -13,9 +14,12 @@ from a2a.utils.errors import ServerError +logger = logging.getLogger(__name__) + + # Regexp patterns for matching -_TASK_NAME_MATCH = r'tasks/(\w+)' -_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotificationConfigs/(\w+)' +_TASK_NAME_MATCH = r'tasks/([\w-]+)' +_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)' class ToProto: diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 53ffb384..56c18bb4 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,6 +1,7 @@ + import asyncio from typing import Any, AsyncGenerator, NamedTuple -from unittest.mock import AsyncMock +from unittest.mock import ANY, AsyncMock import grpc import httpx @@ -19,13 +20,20 @@ AgentCapabilities, AgentCard, AgentInterface, + GetTaskPushNotificationConfigParams, Message, MessageSendParams, Part, + PushNotificationConfig, Role, Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, TaskState, TaskStatus, + TaskStatusUpdateEvent, TextPart, TransportProtocol, ) @@ -39,6 +47,41 @@ kind="task", ) +TASK_FROM_BLOCKING = Task( + id="task-789-blocking", + context_id="ctx-101-blocking", + status=TaskStatus(state=TaskState.completed), + kind="task", +) + +GET_TASK_RESPONSE = Task( + id="task-get-456", + context_id="ctx-get-789", + status=TaskStatus(state=TaskState.working), + kind="task", +) + +CANCEL_TASK_RESPONSE = Task( + id="task-cancel-789", + context_id="ctx-cancel-101", + status=TaskStatus(state=TaskState.canceled), + kind="task", +) + +CALLBACK_CONFIG = TaskPushNotificationConfig( + task_id="task-callback-123", + push_notification_config=PushNotificationConfig( + id="pnc-abc", url="http://callback.example.com", token='' + ), +) + +RESUBSCRIBE_EVENT = TaskStatusUpdateEvent( + task_id="task-resub-456", + context_id="ctx-resub-789", + status=TaskStatus(state=TaskState.working), + final=False, +) + # --- Test Fixtures --- @@ -48,10 +91,26 @@ def mock_request_handler() -> AsyncMock: """Provides a mock RequestHandler for the server-side handlers.""" handler = AsyncMock(spec=RequestHandler) + # Configure on_message_send for non-streaming calls + handler.on_message_send.return_value = TASK_FROM_BLOCKING + + # Configure on_message_send_stream for streaming calls async def stream_side_effect(*args, **kwargs): yield TASK_FROM_STREAM handler.on_message_send_stream.side_effect = stream_side_effect + + # Configure other methods + handler.on_get_task.return_value = GET_TASK_RESPONSE + handler.on_cancel_task.return_value = CANCEL_TASK_RESPONSE + handler.on_set_task_push_notification_config.side_effect = lambda params, context: params + handler.on_get_task_push_notification_config.return_value = CALLBACK_CONFIG + + async def resubscribe_side_effect(*args, **kwargs): + yield RESUBSCRIBE_EVENT + + handler.on_resubscribe_to_task.side_effect = resubscribe_side_effect + return handler @@ -63,11 +122,12 @@ def agent_card() -> AgentCard: description='An agent for integration testing.', url='http://testserver', version='1.0.0', - capabilities=AgentCapabilities(streaming=True), + capabilities=AgentCapabilities(streaming=True, push_notifications=True), skills=[], default_input_modes=['text/plain'], default_output_modes=['text/plain'], preferred_transport=TransportProtocol.jsonrpc, + supports_authenticated_extended_card=True, additional_interfaces=[ AgentInterface( transport=TransportProtocol.http_json, url='http://testserver' @@ -102,7 +162,9 @@ def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): def jsonrpc_setup(http_base_setup) -> TransportSetup: """Sets up the JsonRpcTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - app_builder = A2AFastAPIApplication(agent_card, mock_request_handler) + app_builder = A2AFastAPIApplication( + agent_card, mock_request_handler, extended_agent_card=agent_card + ) app = app_builder.build() httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) transport = JsonRpcTransport( @@ -140,7 +202,7 @@ async def grpc_server_and_handler( await server.stop(0) -# --- The Integration Test --- +# --- The Integration Tests --- @pytest.mark.asyncio @@ -151,11 +213,11 @@ async def grpc_server_and_handler( pytest.param('rest_setup', id='REST'), ], ) -async def test_http_transport_sends_message( +async def test_http_transport_sends_message_streaming( transport_setup_fixture: str, request ) -> None: """ - Integration test for HTTP-based transports (JSON-RPC, REST). + Integration test for HTTP-based transports (JSON-RPC, REST) streaming. """ transport_setup: TransportSetup = request.getfixturevalue( transport_setup_fixture @@ -191,12 +253,12 @@ async def test_http_transport_sends_message( @pytest.mark.asyncio -async def test_grpc_transport_sends_message( +async def test_grpc_transport_sends_message_streaming( grpc_server_and_handler: tuple[str, AsyncMock], agent_card: AgentCard, ) -> None: """ - Integration test specifically for the gRPC transport. + Integration test specifically for the gRPC transport streaming. """ server_address, handler = grpc_server_and_handler agent_card.url = server_address @@ -231,3 +293,428 @@ def channel_factory(address: str) -> Channel: ) await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_sends_message_blocking( + transport_setup_fixture: str, request +) -> None: + """ + Integration test for HTTP-based transports (JSON-RPC, REST) blocking. + """ + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + message_to_send = Message( + role=Role.user, + message_id='msg-integration-test-blocking', + parts=[Part(root=TextPart(text='Hello, blocking test!'))], + ) + params = MessageSendParams(message=message_to_send) + + result = await transport.send_message(request=params) + + assert result.id == TASK_FROM_BLOCKING.id + assert result.context_id == TASK_FROM_BLOCKING.context_id + + handler.on_message_send.assert_awaited_once() + call_args, _ = handler.on_message_send.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_sends_message_blocking( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + """ + Integration test specifically for the gRPC transport blocking. + """ + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + message_to_send = Message( + role=Role.user, + message_id='msg-grpc-integration-test-blocking', + parts=[Part(root=TextPart(text='Hello, gRPC blocking test!'))], + ) + params = MessageSendParams(message=message_to_send) + + result = await transport.send_message(request=params) + + assert result.id == TASK_FROM_BLOCKING.id + assert result.context_id == TASK_FROM_BLOCKING.context_id + + handler.on_message_send.assert_awaited_once() + call_args, _ = handler.on_message_send.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_get_task( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = TaskQueryParams(id=GET_TASK_RESPONSE.id) + result = await transport.get_task(request=params) + + assert result.id == GET_TASK_RESPONSE.id + handler.on_get_task.assert_awaited_once_with(params, ANY) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_task( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = TaskQueryParams(id=GET_TASK_RESPONSE.id) + result = await transport.get_task(request=params) + + assert result.id == GET_TASK_RESPONSE.id + handler.on_get_task.assert_awaited_once() + assert handler.on_get_task.call_args[0][0].id == GET_TASK_RESPONSE.id + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_cancel_task( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = TaskIdParams(id=CANCEL_TASK_RESPONSE.id) + result = await transport.cancel_task(request=params) + + assert result.id == CANCEL_TASK_RESPONSE.id + handler.on_cancel_task.assert_awaited_once_with(params, ANY) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_cancel_task( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = TaskIdParams(id=CANCEL_TASK_RESPONSE.id) + result = await transport.cancel_task(request=params) + + assert result.id == CANCEL_TASK_RESPONSE.id + handler.on_cancel_task.assert_awaited_once() + assert handler.on_cancel_task.call_args[0][0].id == CANCEL_TASK_RESPONSE.id + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_set_task_callback( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = CALLBACK_CONFIG + result = await transport.set_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert result.push_notification_config.id == CALLBACK_CONFIG.push_notification_config.id + assert result.push_notification_config.url == CALLBACK_CONFIG.push_notification_config.url + handler.on_set_task_push_notification_config.assert_awaited_once_with( + params, ANY + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_set_task_callback( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = CALLBACK_CONFIG + result = await transport.set_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert result.push_notification_config.id == CALLBACK_CONFIG.push_notification_config.id + assert result.push_notification_config.url == CALLBACK_CONFIG.push_notification_config.url + handler.on_set_task_push_notification_config.assert_awaited_once() + assert ( + handler.on_set_task_push_notification_config.call_args[0][0].task_id + == CALLBACK_CONFIG.task_id + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_get_task_callback( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = GetTaskPushNotificationConfigParams( + id=CALLBACK_CONFIG.task_id, + push_notification_config_id=CALLBACK_CONFIG.push_notification_config.id, + ) + result = await transport.get_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert result.push_notification_config.id == CALLBACK_CONFIG.push_notification_config.id + assert result.push_notification_config.url == CALLBACK_CONFIG.push_notification_config.url + handler.on_get_task_push_notification_config.assert_awaited_once_with( + params, ANY + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_task_callback( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = GetTaskPushNotificationConfigParams( + id=CALLBACK_CONFIG.task_id, + push_notification_config_id=CALLBACK_CONFIG.push_notification_config.id, + ) + result = await transport.get_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert result.push_notification_config.id == CALLBACK_CONFIG.push_notification_config.id + assert result.push_notification_config.url == CALLBACK_CONFIG.push_notification_config.url + handler.on_get_task_push_notification_config.assert_awaited_once() + assert ( + handler.on_get_task_push_notification_config.call_args[0][0].id + == CALLBACK_CONFIG.task_id + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_resubscribe( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = TaskIdParams(id=RESUBSCRIBE_EVENT.task_id) + stream = transport.resubscribe(request=params) + first_event = await anext(stream) + + assert first_event.task_id == RESUBSCRIBE_EVENT.task_id + handler.on_resubscribe_to_task.assert_called_once_with(params, ANY) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_resubscribe( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = TaskIdParams(id=RESUBSCRIBE_EVENT.task_id) + stream = transport.resubscribe(request=params) + first_event = await anext(stream) + + assert first_event.task_id == RESUBSCRIBE_EVENT.task_id + handler.on_resubscribe_to_task.assert_called_once() + assert ( + handler.on_resubscribe_to_task.call_args[0][0].id + == RESUBSCRIBE_EVENT.task_id + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_get_card( + transport_setup_fixture: str, request, agent_card: AgentCard +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + + # The transport starts with a minimal card, get_card() fetches the full one + transport.agent_card.supports_authenticated_extended_card = True + result = await transport.get_card() + + assert result.name == agent_card.name + assert transport.agent_card.name == agent_card.name + assert transport._needs_extended_card is False + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_card( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, _ = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + # The transport starts with a minimal card, get_card() fetches the full one + transport.agent_card.supports_authenticated_extended_card = True + result = await transport.get_card() + + assert result.name == agent_card.name + assert transport.agent_card.name == agent_card.name + assert transport._needs_extended_card is False + + await transport.close() diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index 05c027fe..74699c68 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -237,13 +237,8 @@ def test_task_id_params_from_proto_invalid_name(self): assert isinstance(exc_info.value.error, types.InvalidParamsError) def test_task_push_config_from_proto_invalid_parent(self): - request = proto_utils.ToProto.task_push_notification_config( - types.TaskPushNotificationConfig( - task_id='test-task-id', - push_notification_config=types.PushNotificationConfig( - url='test_url' - ), - ) + request = a2a_pb2.TaskPushNotificationConfig( + name='invalid-name-format' ) with pytest.raises(ServerError) as exc_info: proto_utils.FromProto.task_push_notification_config(request) From 2cd2fcdcddf8ea15a13f3f91a5231a80f1b48851 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 30 Jul 2025 03:42:12 +0000 Subject: [PATCH 05/12] Fix remaining broken tests --- src/a2a/client/transports/jsonrpc.py | 24 +++++++++++++----------- src/a2a/server/apps/rest/rest_app.py | 4 ++-- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 4bc2dfa6..0b3236c1 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -22,13 +22,14 @@ AgentCard, CancelTaskRequest, CancelTaskResponse, + GetAuthenticatedExtendedCardRequest, + GetAuthenticatedExtendedCardResponse, GetTaskPushNotificationConfigParams, GetTaskPushNotificationConfigRequest, GetTaskPushNotificationConfigResponse, GetTaskRequest, GetTaskResponse, JSONRPCErrorResponse, - JSONRPCRequest, Message, MessageSendParams, SendMessageRequest, @@ -349,23 +350,24 @@ async def get_card( if not self._needs_extended_card: return card - _, modified_kwargs = await self._apply_interceptors( - 'agent/getAuthenticatedExtendedCard', - {}, + request = GetAuthenticatedExtendedCardRequest(id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + request.method, + request.model_dump(mode='json', exclude_none=True), self._get_http_args(context), context, ) response_data = await self._send_request( - JSONRPCRequest( - method='agent/getAuthenticatedExtendedCard', - params={}, - id=str(uuid4()), - ).model_dump(), + payload, modified_kwargs, ) - card = AgentCard.model_validate(response_data) - self.agent_card = card + response = GetAuthenticatedExtendedCardResponse.model_validate( + response_data + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + self.agent_card = response.root.result self._needs_extended_card = False return card diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index ad723f04..5c81d024 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -196,14 +196,14 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: ('/v1/message:send', 'POST'): functools.partial( self._handle_request, self.handler.on_message_send ), - ('/v1/message:stream', 'GET'): functools.partial( + ('/v1/message:stream', 'POST'): functools.partial( self._handle_streaming_request, self.handler.on_message_send_stream, ), ('/v1/tasks/{id}:cancel', 'POST'): functools.partial( self._handle_request, self.handler.on_cancel_task ), - ('/v1/tasks/{id}:subscribe', 'POST'): functools.partial( + ('/v1/tasks/{id}:subscribe', 'GET'): functools.partial( self._handle_streaming_request, self.handler.on_resubscribe_to_task, ), From 43509d8ac0c9b75b270ddf9afe3af0f13cf56e77 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 30 Jul 2025 03:54:05 +0000 Subject: [PATCH 06/12] Gemini authored: attempt to restore JSON-RPC client tests --- src/a2a/utils/helpers.py | 21 ++- tests/client/test_jsonrpc_client.py | 252 +++++++++++++++++++++++++++- 2 files changed, 261 insertions(+), 12 deletions(-) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 5f9ee94b..0760690b 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -149,19 +149,18 @@ async def async_wrapper(self: Any, *args, **kwargs) -> Any: return await function(self, *args, **kwargs) return async_wrapper - else: - @functools.wraps(function) - def sync_wrapper(self: Any, *args, **kwargs) -> Any: - if not expression(self): - final_message = error_message or str(expression) - logger.error(f'Unsupported Operation: {final_message}') - raise ServerError( - UnsupportedOperationError(message=final_message) - ) - return function(self, *args, **kwargs) + @functools.wraps(function) + def sync_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logger.error(f'Unsupported Operation: {final_message}') + raise ServerError( + UnsupportedOperationError(message=final_message) + ) + return function(self, *args, **kwargs) - return sync_wrapper + return sync_wrapper return decorator diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index 0441fce5..57c8e191 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -528,4 +528,254 @@ async def test_send_message_client_timeout( with pytest.raises(A2AClientTimeoutError) as exc_info: await client.send_message(request=params) - assert 'Request timed out' in str(exc_info.value) + assert 'Client Request timed out' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_get_task_success( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = TaskQueryParams(id='task-abc') + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': MINIMAL_TASK, + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.get_task(request=params) + + assert isinstance(response, Task) + assert response.model_dump() == Task.model_validate( + MINIMAL_TASK + ).model_dump() + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/get' + + @pytest.mark.asyncio + async def test_cancel_task_success( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = TaskIdParams(id='task-abc') + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': MINIMAL_CANCELLED_TASK, + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.cancel_task(request=params) + + assert isinstance(response, Task) + assert response.model_dump() == Task.model_validate( + MINIMAL_CANCELLED_TASK + ).model_dump() + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/cancel' + + @pytest.mark.asyncio + async def test_set_task_callback_success( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = TaskPushNotificationConfig( + task_id='task-abc', + push_notification_config=PushNotificationConfig( + url='http://callback.com' + ), + ) + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': params.model_dump(mode='json'), + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.set_task_callback(request=params) + + assert isinstance(response, TaskPushNotificationConfig) + assert response.model_dump() == params.model_dump() + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/pushNotificationConfig/set' + + @pytest.mark.asyncio + async def test_get_task_callback_success( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = TaskIdParams(id='task-abc') + expected_response = TaskPushNotificationConfig( + task_id='task-abc', + push_notification_config=PushNotificationConfig( + url='http://callback.com' + ), + ) + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': expected_response.model_dump(mode='json'), + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.get_task_callback(request=params) + + assert isinstance(response, TaskPushNotificationConfig) + assert response.model_dump() == expected_response.model_dump() + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/pushNotificationConfig/get' + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_sse_error( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.side_effect = SSEError( + 'Simulated SSE error' + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientHTTPError): + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_json_error( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + sse_event = ServerSentEvent(data='{invalid json') + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list( + [sse_event] + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientJSONError): + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_request_error( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.side_effect = httpx.RequestError( + 'Simulated request error', request=MagicMock() + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientHTTPError): + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] + + @pytest.mark.asyncio + async def test_get_card_no_card_provided(self, mock_httpx_client: AsyncMock): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, url=self.AGENT_URL + ) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = AGENT_CARD.model_dump(mode='json') + mock_httpx_client.get.return_value = mock_response + + card = await client.get_card() + + assert card == AGENT_CARD + mock_httpx_client.get.assert_called_once() + + @pytest.mark.asyncio + async def test_get_card_with_extended_card_support( + self, mock_httpx_client: AsyncMock + ): + agent_card = AGENT_CARD.model_copy( + update={'supports_authenticated_extended_card': True} + ) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=agent_card + ) + + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': AGENT_CARD_EXTENDED.model_dump(mode='json'), + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + card = await client.get_card() + + assert card == agent_card + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'agent/getAuthenticatedExtendedCard' + + @pytest.mark.asyncio + async def test_close(self, mock_httpx_client: AsyncMock): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, url=self.AGENT_URL + ) + await client.close() + mock_httpx_client.aclose.assert_called_once() From d057808c255cbe7c8a1fb58809e88ecc3a13cbbd Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 30 Jul 2025 04:37:08 +0000 Subject: [PATCH 07/12] Gemini authored: attempt to create a backwards compatible A2AClient and A2AGrpcClient --- src/a2a/client/__init__.py | 21 +++ src/a2a/client/legacy.py | 214 +++++++++++++++++++++++++++++ src/a2a/client/legacy_grpc.py | 26 ++++ tests/client/test_legacy_client.py | 188 +++++++++++++++++++++++++ 4 files changed, 449 insertions(+) create mode 100644 src/a2a/client/legacy.py create mode 100644 src/a2a/client/legacy_grpc.py create mode 100644 tests/client/test_legacy_client.py diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 96d27033..bd81fa94 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -20,18 +20,39 @@ A2AClientTimeoutError, ) from a2a.client.helpers import create_text_message_object +from a2a.client.legacy import A2AClient from a2a.client.middleware import ClientCallContext, ClientCallInterceptor logger = logging.getLogger(__name__) +try: + from a2a.client.legacy_grpc import A2AGrpcClient +except ImportError as e: + _original_error = e + logger.debug( + "A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s", + _original_error, + ) + + class A2AGrpcClient: # type: ignore + """Placeholder for A2AGrpcClient when dependencies are not installed.""" + + def __init__(self, *args, **kwargs): + raise ImportError( + "To use A2AGrpcClient, its dependencies must be installed. " + 'You can install them with \'pip install "a2a-sdk[grpc]"\'' + ) from _original_error + __all__ = [ "A2ACardResolver", + "A2AClient", "A2AClientError", "A2AClientHTTPError", "A2AClientJSONError", "A2AClientTimeoutError", + "A2AGrpcClient", "AuthInterceptor", "Client", "ClientCallContext", diff --git a/src/a2a/client/legacy.py b/src/a2a/client/legacy.py new file mode 100644 index 00000000..544ccbad --- /dev/null +++ b/src/a2a/client/legacy.py @@ -0,0 +1,214 @@ +"""Backwards compatibility layer for legacy A2A clients.""" + +import warnings + +from collections.abc import AsyncGenerator +from typing import Any + +import httpx + +from a2a.client.errors import A2AClientJSONRPCError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.types import ( + AgentCard, + CancelTaskRequest, + CancelTaskResponse, + CancelTaskSuccessResponse, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigSuccessResponse, + GetTaskRequest, + GetTaskResponse, + GetTaskSuccessResponse, + JSONRPCErrorResponse, + SendMessageRequest, + SendMessageResponse, + SendMessageSuccessResponse, + SendStreamingMessageRequest, + SendStreamingMessageResponse, + SendStreamingMessageSuccessResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + SetTaskPushNotificationConfigSuccessResponse, + TaskResubscriptionRequest, +) + + +class A2AClient: + """[DEPRECATED] Backwards compatibility wrapper for the JSON-RPC client.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + warnings.warn( + 'A2AClient is deprecated and will be removed in a future version. ' + 'Use ClientFactory to create a client with a JSON-RPC transport.', + DeprecationWarning, + stacklevel=2, + ) + self._transport = JsonRpcTransport( + httpx_client, agent_card, url, interceptors + ) + + async def send_message( + self, + request: SendMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SendMessageResponse: + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + + try: + result = await self._transport.send_message( + request.params, context=context + ) + return SendMessageResponse( + root=SendMessageSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return SendMessageResponse(root=JSONRPCErrorResponse(error=e.error)) + + async def send_message_streaming( + self, + request: SendStreamingMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse, None]: + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + + async for result in self._transport.send_message_streaming( + request.params, context=context + ): + yield SendStreamingMessageResponse( + root=SendStreamingMessageSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + + async def get_task( + self, + request: GetTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskResponse: + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.get_task( + request.params, context=context + ) + return GetTaskResponse( + root=GetTaskSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return GetTaskResponse(root=JSONRPCErrorResponse(error=e.error)) + + async def cancel_task( + self, + request: CancelTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> CancelTaskResponse: + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.cancel_task( + request.params, context=context + ) + return CancelTaskResponse( + root=CancelTaskSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return CancelTaskResponse(root=JSONRPCErrorResponse(error=e.error)) + + async def set_task_callback( + self, + request: SetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SetTaskPushNotificationConfigResponse: + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.set_task_callback( + request.params, context=context + ) + return SetTaskPushNotificationConfigResponse( + root=SetTaskPushNotificationConfigSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return SetTaskPushNotificationConfigResponse( + root=JSONRPCErrorResponse(error=e.error) + ) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskPushNotificationConfigResponse: + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.get_task_callback( + request.params, context=context + ) + return GetTaskPushNotificationConfigResponse( + root=GetTaskPushNotificationConfigSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return GetTaskPushNotificationConfigResponse( + root=JSONRPCErrorResponse(error=e.error) + ) + + async def resubscribe( + self, + request: TaskResubscriptionRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse, None]: + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + + async for result in self._transport.resubscribe( + request.params, context=context + ): + yield SendStreamingMessageResponse( + root=SendStreamingMessageSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + + async def get_card( + self, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AgentCard: + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + return await self._transport.get_card(context=context) diff --git a/src/a2a/client/legacy_grpc.py b/src/a2a/client/legacy_grpc.py new file mode 100644 index 00000000..b3ba150f --- /dev/null +++ b/src/a2a/client/legacy_grpc.py @@ -0,0 +1,26 @@ +"""Backwards compatibility layer for the legacy A2A gRPC client.""" + +import warnings + +from a2a.client.transports.grpc import GrpcTransport +from a2a.grpc import a2a_pb2_grpc +from a2a.types import AgentCard + + +class A2AGrpcClient(GrpcTransport): + """ + [DEPRECATED] Backwards compatibility wrapper for the gRPC client. + """ + + def __init__( + self, + grpc_stub: "a2a_pb2_grpc.A2AServiceStub", + agent_card: AgentCard, + ): + warnings.warn( + "A2AGrpcClient is deprecated and will be removed in a future version. " + "Use ClientFactory to create a client with a gRPC transport.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(grpc_stub, agent_card) \ No newline at end of file diff --git a/tests/client/test_legacy_client.py b/tests/client/test_legacy_client.py new file mode 100644 index 00000000..ea45e6a1 --- /dev/null +++ b/tests/client/test_legacy_client.py @@ -0,0 +1,188 @@ +"""Tests for the legacy client compatibility layer.""" + +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest + +from a2a.client import A2AClient, A2AGrpcClient +from a2a.types import ( + AgentCapabilities, + AgentCard, + GetTaskRequest, + Message, + MessageSendParams, + Part, + Role, + SendMessageRequest, + Task, + TaskQueryParams, + TaskState, + TaskStatus, + TextPart, +) + + +@pytest.fixture +def mock_httpx_client() -> AsyncMock: + return AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def mock_grpc_stub() -> AsyncMock: + stub = AsyncMock() + stub._channel = MagicMock() + return stub + + +@pytest.fixture +def jsonrpc_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='A test agent', + url='http://test.agent.com/rpc', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport='jsonrpc', + ) + + +@pytest.fixture +def grpc_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='A test agent', + url='http://test.agent.com/rpc', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport='grpc', + ) + + +@pytest.mark.asyncio +async def test_a2a_client_send_message( + mock_httpx_client: AsyncMock, jsonrpc_agent_card: AgentCard +): + """Tests for the legacy client compatibility layer.""" + + +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest + +from a2a.client import A2AClient, A2AGrpcClient +from a2a.types import ( + AgentCapabilities, + AgentCard, + GetTaskRequest, + Message, + MessageSendParams, + Part, + Role, + SendMessageRequest, + Task, + TaskQueryParams, + TaskState, + TaskStatus, + TextPart, +) + + +@pytest.fixture +def mock_httpx_client() -> AsyncMock: + return AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def mock_grpc_stub() -> AsyncMock: + stub = AsyncMock() + stub._channel = MagicMock() + return stub + + +@pytest.fixture +def jsonrpc_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='A test agent', + url='http://test.agent.com/rpc', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport='jsonrpc', + ) + + +@pytest.fixture +def grpc_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='A test agent', + url='http://test.agent.com/rpc', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport='grpc', + ) + + +@pytest.mark.asyncio +async def test_a2a_client_send_message( + mock_httpx_client: AsyncMock, jsonrpc_agent_card: AgentCard +): + client = A2AClient( + httpx_client=mock_httpx_client, agent_card=jsonrpc_agent_card + ) + + # Mock the underlying transport's send_message method + mock_response_task = Task( + id='task-123', + context_id='ctx-456', + status=TaskStatus(state=TaskState.completed), + ) + + client._transport.send_message = AsyncMock(return_value=mock_response_task) + + message = Message( + message_id='msg-123', + role=Role.user, + parts=[Part(root=TextPart(text='Hello'))], + ) + request = SendMessageRequest( + id='req-123', params=MessageSendParams(message=message) + ) + response = await client.send_message(request) + + assert response.root.result.id == 'task-123' + + +@pytest.mark.asyncio +async def test_a2a_grpc_client_get_task( + mock_grpc_stub: AsyncMock, grpc_agent_card: AgentCard +): + client = A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=grpc_agent_card) + + mock_response_task = Task( + id='task-456', + context_id='ctx-789', + status=TaskStatus(state=TaskState.working), + ) + + client.get_task.return_value = mock_response_task + + params = TaskQueryParams(id='task-456') + response = await client.get_task(params) + + assert response.id == 'task-456' + client.get_task.assert_awaited_once_with(params) From b32b0da44705b5220a5b51318ee264c1a2cc7fbd Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 30 Jul 2025 04:41:36 +0000 Subject: [PATCH 08/12] Restore some comments gemini mysteriously deleted --- src/a2a/client/client_factory.py | 41 +++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 1e312793..fbace63c 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -38,7 +38,20 @@ class ClientFactory: - """ClientFactory is used to generate the appropriate client for the agent.""" + """ClientFactory is used to generate the appropriate client for the agent. + + The factory is configured with a `ClientConfig` and optionally a list of + `Consumer`s to use for all generated `Client`s. The expected use is: + + factory = ClientFactory(config, consumers) + # Optionally register custom client implementations + factory.register('my_customer_transport', NewCustomTransportClient) + # Then with an agent card make a client with additional consumers and + # interceptors + client = factory.create(card, additional_consumers, interceptors) + # Now the client can be used the same regardless of transport and + # aligns client config with server capabilities. + """ def __init__( self, @@ -92,7 +105,22 @@ def create( consumers: list[Consumer] | None = None, interceptors: list[ClientCallInterceptor] | None = None, ) -> Client: - """Create a new `Client` for the provided `AgentCard`.""" + """Create a new `Client` for the provided `AgentCard`. + + Args: + card: An `AgentCard` defining the characteristics of the agent. + consumers: A list of `Consumer` methods to pass responses to. + interceptors: A list of interceptors to use for each request. These + are used for things like attaching credentials or http headers + to all outbound requests. + + Returns: + A `Client` object. + + Raises: + If there is no valid matching of the client configuration with the + server configuration, a `ValueError` is raised. + """ server_set = [card.preferred_transport or TransportProtocol.jsonrpc] if card.additional_interfaces: server_set.extend([x.transport for x in card.additional_interfaces]) @@ -131,7 +159,14 @@ def create( def minimal_agent_card( url: str, transports: list[str] | None = None ) -> AgentCard: - """Generates a minimal card to simplify bootstrapping client creation.""" + """Generates a minimal card to simplify bootstrapping client creation. + + This minimal card is not viable itself to interact with the remote agent. + Instead this is a short hand way to take a known url and transport option + and interact with the get card endpoint of the agent server to get the + correct agent card. This pattern is necessary for gRPC based card access + as typically these servers won't expose a well known path card. + """ if transports is None: transports = [] return AgentCard( From 12e595df2de57822f05abb3496b339cca51c5a80 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 30 Jul 2025 14:57:56 +0000 Subject: [PATCH 09/12] Fix an issue with ClientFactory not respecting transport URL, add tests --- src/a2a/client/client_factory.py | 30 +++++---- src/a2a/client/transports/jsonrpc.py | 6 +- src/a2a/client/transports/rest.py | 6 +- tests/client/test_client_factory.py | 93 ++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 18 deletions(-) create mode 100644 tests/client/test_client_factory.py diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index fbace63c..b2d914ee 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -32,7 +32,7 @@ TransportProducer = Callable[ - [AgentCard, ClientConfig, list[ClientCallInterceptor]], + [AgentCard, str, ClientConfig, list[ClientCallInterceptor]], ClientTransport, ] @@ -68,28 +68,28 @@ def __init__( def _register_defaults(self) -> None: self.register( TransportProtocol.jsonrpc, - lambda card, config, interceptors: JsonRpcTransport( + lambda card, url, config, interceptors: JsonRpcTransport( config.httpx_client or httpx.AsyncClient(), card, - card.url, + url, interceptors, ), ) self.register( TransportProtocol.http_json, - lambda card, config, interceptors: RestTransport( + lambda card, url, config, interceptors: RestTransport( config.httpx_client or httpx.AsyncClient(), card, - card.url, + url, interceptors, ), ) if GrpcTransport: self.register( TransportProtocol.grpc, - lambda card, config, interceptors: GrpcTransport( + lambda card, url, config, interceptors: GrpcTransport( a2a_pb2_grpc.A2AServiceStub( - config.grpc_channel_factory(card.url) + config.grpc_channel_factory(url) ), card, ), @@ -121,24 +121,30 @@ def create( If there is no valid matching of the client configuration with the server configuration, a `ValueError` is raised. """ - server_set = [card.preferred_transport or TransportProtocol.jsonrpc] + server_preferred = card.preferred_transport or TransportProtocol.jsonrpc + server_set = {server_preferred: card.url} if card.additional_interfaces: - server_set.extend([x.transport for x in card.additional_interfaces]) + server_set.update( + {x.transport: x.url for x in card.additional_interfaces} + ) client_set = self._config.supported_transports or [ TransportProtocol.jsonrpc ] transport_protocol = None + transport_url = None if self._config.use_client_preference: for x in client_set: if x in server_set: transport_protocol = x + transport_url = server_set[x] break else: - for x in server_set: + for x, url in server_set.items(): if x in client_set: transport_protocol = x + transport_url = url break - if not transport_protocol: + if not transport_protocol or not transport_url: raise ValueError('no compatible transports found.') if transport_protocol not in self._registry: raise ValueError(f'no client available for {transport_protocol}') @@ -148,7 +154,7 @@ def create( all_consumers.extend(consumers) transport = self._registry[transport_protocol]( - card, self._config, interceptors or [] + card, transport_url, self._config, interceptors or [] ) return BaseClient( diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 0b3236c1..868b3a01 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -64,10 +64,10 @@ def __init__( interceptors: list[ClientCallInterceptor] | None = None, ): """Initializes the JsonRpcTransport.""" - if agent_card: - self.url = agent_card.url - elif url: + if url: self.url = url + elif agent_card: + self.url = agent_card.url else: raise ValueError('Must provide either agent_card or url') diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 9b22d165..7c20fe9c 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -45,10 +45,10 @@ def __init__( interceptors: list[ClientCallInterceptor] | None = None, ): """Initializes the RestTransport.""" - if agent_card: - self.url = agent_card.url - elif url: + if url: self.url = url + elif agent_card: + self.url = agent_card.url else: raise ValueError('Must provide either agent_card or url') if self.url.endswith('/'): diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py new file mode 100644 index 00000000..109c489d --- /dev/null +++ b/tests/client/test_client_factory.py @@ -0,0 +1,93 @@ +"""Tests for the ClientFactory.""" + +import httpx +import pytest + +from a2a.client import ClientConfig, ClientFactory +from a2a.client.transports import JsonRpcTransport, RestTransport +from a2a.types import ( + AgentCard, + AgentCapabilities, + AgentInterface, + TransportProtocol, +) + + +@pytest.fixture +def base_agent_card() -> AgentCard: + """Provides a base AgentCard for tests.""" + return AgentCard( + name="Test Agent", + description="An agent for testing.", + url="http://primary-url.com", + version="1.0.0", + capabilities=AgentCapabilities(), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport=TransportProtocol.jsonrpc, + ) + + +def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard): + """Verify that the factory selects the preferred transport by default.""" + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[TransportProtocol.jsonrpc, TransportProtocol.http_json], + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, JsonRpcTransport) + assert client._transport.url == "http://primary-url.com" + + +def test_client_factory_selects_secondary_transport_url(base_agent_card: AgentCard): + """Verify that the factory selects the correct URL for a secondary transport.""" + base_agent_card.additional_interfaces = [ + AgentInterface( + transport=TransportProtocol.http_json, url="http://secondary-url.com" + ) + ] + # Client prefers REST, which is available as a secondary transport + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[TransportProtocol.http_json, TransportProtocol.jsonrpc], + use_client_preference=True, + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, RestTransport) + assert client._transport.url == "http://secondary-url.com" + + +def test_client_factory_server_preference(base_agent_card: AgentCard): + """Verify that the factory respects server transport preference.""" + base_agent_card.preferred_transport = TransportProtocol.http_json + base_agent_card.additional_interfaces = [ + AgentInterface( + transport=TransportProtocol.jsonrpc, url="http://secondary-url.com" + ) + ] + # Client supports both, but server prefers REST + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[TransportProtocol.jsonrpc, TransportProtocol.http_json], + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, RestTransport) + assert client._transport.url == "http://primary-url.com" + + +def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): + """Verify that the factory raises an error if no compatible transport is found.""" + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[TransportProtocol.grpc], + ) + factory = ClientFactory(config) + with pytest.raises(ValueError, match="no compatible transports found"): + factory.create(base_agent_card) From 9182fdb778e4f9d49edea25e60c7b9b63bbfca77 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Wed, 30 Jul 2025 16:12:22 +0100 Subject: [PATCH 10/12] formatting --- src/a2a/client/__init__.py | 42 ++++----- src/a2a/client/legacy_grpc.py | 12 ++- src/a2a/client/transports/__init__.py | 8 +- src/a2a/utils/proto_utils.py | 4 +- tests/client/test_auth_middleware.py | 3 +- tests/client/test_client_factory.py | 42 +++++---- tests/client/test_jsonrpc_client.py | 26 +++--- tests/client/test_legacy_client.py | 17 +--- .../test_client_server_integration.py | 85 ++++++++++++------- tests/utils/test_proto_utils.py | 4 +- 10 files changed, 136 insertions(+), 107 deletions(-) diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index bd81fa94..dae06357 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -31,7 +31,7 @@ except ImportError as e: _original_error = e logger.debug( - "A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s", + 'A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s', _original_error, ) @@ -40,29 +40,29 @@ class A2AGrpcClient: # type: ignore def __init__(self, *args, **kwargs): raise ImportError( - "To use A2AGrpcClient, its dependencies must be installed. " + 'To use A2AGrpcClient, its dependencies must be installed. ' 'You can install them with \'pip install "a2a-sdk[grpc]"\'' ) from _original_error __all__ = [ - "A2ACardResolver", - "A2AClient", - "A2AClientError", - "A2AClientHTTPError", - "A2AClientJSONError", - "A2AClientTimeoutError", - "A2AGrpcClient", - "AuthInterceptor", - "Client", - "ClientCallContext", - "ClientCallInterceptor", - "ClientConfig", - "ClientEvent", - "ClientFactory", - "Consumer", - "CredentialService", - "InMemoryContextCredentialStore", - "create_text_message_object", - "minimal_agent_card", + 'A2ACardResolver', + 'A2AClient', + 'A2AClientError', + 'A2AClientHTTPError', + 'A2AClientJSONError', + 'A2AClientTimeoutError', + 'A2AGrpcClient', + 'AuthInterceptor', + 'Client', + 'ClientCallContext', + 'ClientCallInterceptor', + 'ClientConfig', + 'ClientEvent', + 'ClientFactory', + 'Consumer', + 'CredentialService', + 'InMemoryContextCredentialStore', + 'create_text_message_object', + 'minimal_agent_card', ] diff --git a/src/a2a/client/legacy_grpc.py b/src/a2a/client/legacy_grpc.py index b3ba150f..9a4b1656 100644 --- a/src/a2a/client/legacy_grpc.py +++ b/src/a2a/client/legacy_grpc.py @@ -8,19 +8,17 @@ class A2AGrpcClient(GrpcTransport): - """ - [DEPRECATED] Backwards compatibility wrapper for the gRPC client. - """ + """[DEPRECATED] Backwards compatibility wrapper for the gRPC client.""" def __init__( self, - grpc_stub: "a2a_pb2_grpc.A2AServiceStub", + grpc_stub: 'a2a_pb2_grpc.A2AServiceStub', agent_card: AgentCard, ): warnings.warn( - "A2AGrpcClient is deprecated and will be removed in a future version. " - "Use ClientFactory to create a client with a gRPC transport.", + 'A2AGrpcClient is deprecated and will be removed in a future version. ' + 'Use ClientFactory to create a client with a gRPC transport.', DeprecationWarning, stacklevel=2, ) - super().__init__(grpc_stub, agent_card) \ No newline at end of file + super().__init__(grpc_stub, agent_card) diff --git a/src/a2a/client/transports/__init__.py b/src/a2a/client/transports/__init__.py index 8bcca4e4..0e600ff4 100644 --- a/src/a2a/client/transports/__init__.py +++ b/src/a2a/client/transports/__init__.py @@ -12,8 +12,8 @@ __all__ = [ - "ClientTransport", - "GrpcTransport", - "JsonRpcTransport", - "RestTransport", + 'ClientTransport', + 'GrpcTransport', + 'JsonRpcTransport', + 'RestTransport', ] diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 05f03ed1..541504d5 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -19,7 +19,9 @@ # Regexp patterns for matching _TASK_NAME_MATCH = r'tasks/([\w-]+)' -_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)' +_TASK_PUSH_CONFIG_NAME_MATCH = ( + r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)' +) class ToProto: diff --git a/tests/client/test_auth_middleware.py b/tests/client/test_auth_middleware.py index 1c37992f..4f53ca3f 100644 --- a/tests/client/test_auth_middleware.py +++ b/tests/client/test_auth_middleware.py @@ -1,4 +1,5 @@ import json + from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -17,9 +18,9 @@ InMemoryContextCredentialStore, ) from a2a.types import ( + APIKeySecurityScheme, AgentCapabilities, AgentCard, - APIKeySecurityScheme, AuthorizationCodeOAuthFlow, HTTPAuthSecurityScheme, In, diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index 109c489d..d615bbff 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -6,8 +6,8 @@ from a2a.client import ClientConfig, ClientFactory from a2a.client.transports import JsonRpcTransport, RestTransport from a2a.types import ( - AgentCard, AgentCapabilities, + AgentCard, AgentInterface, TransportProtocol, ) @@ -17,10 +17,10 @@ def base_agent_card() -> AgentCard: """Provides a base AgentCard for tests.""" return AgentCard( - name="Test Agent", - description="An agent for testing.", - url="http://primary-url.com", - version="1.0.0", + name='Test Agent', + description='An agent for testing.', + url='http://primary-url.com', + version='1.0.0', capabilities=AgentCapabilities(), skills=[], default_input_modes=[], @@ -33,33 +33,42 @@ def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard): """Verify that the factory selects the preferred transport by default.""" config = ClientConfig( httpx_client=httpx.AsyncClient(), - supported_transports=[TransportProtocol.jsonrpc, TransportProtocol.http_json], + supported_transports=[ + TransportProtocol.jsonrpc, + TransportProtocol.http_json, + ], ) factory = ClientFactory(config) client = factory.create(base_agent_card) assert isinstance(client._transport, JsonRpcTransport) - assert client._transport.url == "http://primary-url.com" + assert client._transport.url == 'http://primary-url.com' -def test_client_factory_selects_secondary_transport_url(base_agent_card: AgentCard): +def test_client_factory_selects_secondary_transport_url( + base_agent_card: AgentCard, +): """Verify that the factory selects the correct URL for a secondary transport.""" base_agent_card.additional_interfaces = [ AgentInterface( - transport=TransportProtocol.http_json, url="http://secondary-url.com" + transport=TransportProtocol.http_json, + url='http://secondary-url.com', ) ] # Client prefers REST, which is available as a secondary transport config = ClientConfig( httpx_client=httpx.AsyncClient(), - supported_transports=[TransportProtocol.http_json, TransportProtocol.jsonrpc], + supported_transports=[ + TransportProtocol.http_json, + TransportProtocol.jsonrpc, + ], use_client_preference=True, ) factory = ClientFactory(config) client = factory.create(base_agent_card) assert isinstance(client._transport, RestTransport) - assert client._transport.url == "http://secondary-url.com" + assert client._transport.url == 'http://secondary-url.com' def test_client_factory_server_preference(base_agent_card: AgentCard): @@ -67,19 +76,22 @@ def test_client_factory_server_preference(base_agent_card: AgentCard): base_agent_card.preferred_transport = TransportProtocol.http_json base_agent_card.additional_interfaces = [ AgentInterface( - transport=TransportProtocol.jsonrpc, url="http://secondary-url.com" + transport=TransportProtocol.jsonrpc, url='http://secondary-url.com' ) ] # Client supports both, but server prefers REST config = ClientConfig( httpx_client=httpx.AsyncClient(), - supported_transports=[TransportProtocol.jsonrpc, TransportProtocol.http_json], + supported_transports=[ + TransportProtocol.jsonrpc, + TransportProtocol.http_json, + ], ) factory = ClientFactory(config) client = factory.create(base_agent_card) assert isinstance(client._transport, RestTransport) - assert client._transport.url == "http://primary-url.com" + assert client._transport.url == 'http://primary-url.com' def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): @@ -89,5 +101,5 @@ def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): supported_transports=[TransportProtocol.grpc], ) factory = ClientFactory(config) - with pytest.raises(ValueError, match="no compatible transports found"): + with pytest.raises(ValueError, match='no compatible transports found'): factory.create(base_agent_card) diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index 57c8e191..a4a9b96e 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -1,11 +1,13 @@ import json + from collections.abc import AsyncGenerator from typing import Any -from unittest.mock import ANY, AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest -from httpx_sse import EventSource, ServerSentEvent, SSEError + +from httpx_sse import EventSource, SSEError, ServerSentEvent from a2a.client import ( A2ACardResolver, @@ -27,12 +29,12 @@ SendMessageSuccessResponse, Task, TaskIdParams, - TaskNotCancelableError, TaskPushNotificationConfig, TaskQueryParams, ) from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH + AGENT_CARD = AgentCard( name='Hello World Agent', description='Just a hello world agent', @@ -550,9 +552,10 @@ async def test_get_task_success( response = await client.get_task(request=params) assert isinstance(response, Task) - assert response.model_dump() == Task.model_validate( - MINIMAL_TASK - ).model_dump() + assert ( + response.model_dump() + == Task.model_validate(MINIMAL_TASK).model_dump() + ) mock_send_request.assert_called_once() sent_payload = mock_send_request.call_args.args[0] assert sent_payload['method'] == 'tasks/get' @@ -577,9 +580,10 @@ async def test_cancel_task_success( response = await client.cancel_task(request=params) assert isinstance(response, Task) - assert response.model_dump() == Task.model_validate( - MINIMAL_CANCELLED_TASK - ).model_dump() + assert ( + response.model_dump() + == Task.model_validate(MINIMAL_CANCELLED_TASK).model_dump() + ) mock_send_request.assert_called_once() sent_payload = mock_send_request.call_args.args[0] assert sent_payload['method'] == 'tasks/cancel' @@ -731,7 +735,9 @@ async def test_send_message_streaming_request_error( ] @pytest.mark.asyncio - async def test_get_card_no_card_provided(self, mock_httpx_client: AsyncMock): + async def test_get_card_no_card_provided( + self, mock_httpx_client: AsyncMock + ): client = JsonRpcTransport( httpx_client=mock_httpx_client, url=self.AGENT_URL ) diff --git a/tests/client/test_legacy_client.py b/tests/client/test_legacy_client.py index ea45e6a1..b15872d8 100644 --- a/tests/client/test_legacy_client.py +++ b/tests/client/test_legacy_client.py @@ -9,7 +9,6 @@ from a2a.types import ( AgentCapabilities, AgentCard, - GetTaskRequest, Message, MessageSendParams, Part, @@ -72,26 +71,12 @@ async def test_a2a_client_send_message( """Tests for the legacy client compatibility layer.""" -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock -import httpx import pytest -from a2a.client import A2AClient, A2AGrpcClient from a2a.types import ( - AgentCapabilities, AgentCard, - GetTaskRequest, - Message, - MessageSendParams, - Part, - Role, - SendMessageRequest, - Task, - TaskQueryParams, - TaskState, - TaskStatus, - TextPart, ) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 56c18bb4..928ab2ea 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,14 +1,15 @@ - import asyncio -from typing import Any, AsyncGenerator, NamedTuple + +from collections.abc import AsyncGenerator +from typing import NamedTuple from unittest.mock import ANY, AsyncMock import grpc import httpx import pytest import pytest_asyncio + from grpc.aio import Channel -from starlette.testclient import TestClient from a2a.client.transports import JsonRpcTransport, RestTransport from a2a.client.transports.base import ClientTransport @@ -27,7 +28,6 @@ PushNotificationConfig, Role, Task, - TaskArtifactUpdateEvent, TaskIdParams, TaskPushNotificationConfig, TaskQueryParams, @@ -38,46 +38,47 @@ TransportProtocol, ) + # --- Test Constants --- TASK_FROM_STREAM = Task( - id="task-123-stream", - context_id="ctx-456-stream", + id='task-123-stream', + context_id='ctx-456-stream', status=TaskStatus(state=TaskState.completed), - kind="task", + kind='task', ) TASK_FROM_BLOCKING = Task( - id="task-789-blocking", - context_id="ctx-101-blocking", + id='task-789-blocking', + context_id='ctx-101-blocking', status=TaskStatus(state=TaskState.completed), - kind="task", + kind='task', ) GET_TASK_RESPONSE = Task( - id="task-get-456", - context_id="ctx-get-789", + id='task-get-456', + context_id='ctx-get-789', status=TaskStatus(state=TaskState.working), - kind="task", + kind='task', ) CANCEL_TASK_RESPONSE = Task( - id="task-cancel-789", - context_id="ctx-cancel-101", + id='task-cancel-789', + context_id='ctx-cancel-101', status=TaskStatus(state=TaskState.canceled), - kind="task", + kind='task', ) CALLBACK_CONFIG = TaskPushNotificationConfig( - task_id="task-callback-123", + task_id='task-callback-123', push_notification_config=PushNotificationConfig( - id="pnc-abc", url="http://callback.example.com", token='' + id='pnc-abc', url='http://callback.example.com', token='' ), ) RESUBSCRIBE_EVENT = TaskStatusUpdateEvent( - task_id="task-resub-456", - context_id="ctx-resub-789", + task_id='task-resub-456', + context_id='ctx-resub-789', status=TaskStatus(state=TaskState.working), final=False, ) @@ -103,7 +104,9 @@ async def stream_side_effect(*args, **kwargs): # Configure other methods handler.on_get_task.return_value = GET_TASK_RESPONSE handler.on_cancel_task.return_value = CANCEL_TASK_RESPONSE - handler.on_set_task_push_notification_config.side_effect = lambda params, context: params + handler.on_set_task_push_notification_config.side_effect = ( + lambda params, context: params + ) handler.on_get_task_push_notification_config.return_value = CALLBACK_CONFIG async def resubscribe_side_effect(*args, **kwargs): @@ -506,8 +509,14 @@ async def test_http_transport_set_task_callback( result = await transport.set_task_callback(request=params) assert result.task_id == CALLBACK_CONFIG.task_id - assert result.push_notification_config.id == CALLBACK_CONFIG.push_notification_config.id - assert result.push_notification_config.url == CALLBACK_CONFIG.push_notification_config.url + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) handler.on_set_task_push_notification_config.assert_awaited_once_with( params, ANY ) @@ -534,8 +543,14 @@ def channel_factory(address: str) -> Channel: result = await transport.set_task_callback(request=params) assert result.task_id == CALLBACK_CONFIG.task_id - assert result.push_notification_config.id == CALLBACK_CONFIG.push_notification_config.id - assert result.push_notification_config.url == CALLBACK_CONFIG.push_notification_config.url + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) handler.on_set_task_push_notification_config.assert_awaited_once() assert ( handler.on_set_task_push_notification_config.call_args[0][0].task_id @@ -569,8 +584,14 @@ async def test_http_transport_get_task_callback( result = await transport.get_task_callback(request=params) assert result.task_id == CALLBACK_CONFIG.task_id - assert result.push_notification_config.id == CALLBACK_CONFIG.push_notification_config.id - assert result.push_notification_config.url == CALLBACK_CONFIG.push_notification_config.url + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) handler.on_get_task_push_notification_config.assert_awaited_once_with( params, ANY ) @@ -600,8 +621,14 @@ def channel_factory(address: str) -> Channel: result = await transport.get_task_callback(request=params) assert result.task_id == CALLBACK_CONFIG.task_id - assert result.push_notification_config.id == CALLBACK_CONFIG.push_notification_config.id - assert result.push_notification_config.url == CALLBACK_CONFIG.push_notification_config.url + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) handler.on_get_task_push_notification_config.assert_awaited_once() assert ( handler.on_get_task_push_notification_config.call_args[0][0].id diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index 74699c68..83848c24 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -237,9 +237,7 @@ def test_task_id_params_from_proto_invalid_name(self): assert isinstance(exc_info.value.error, types.InvalidParamsError) def test_task_push_config_from_proto_invalid_parent(self): - request = a2a_pb2.TaskPushNotificationConfig( - name='invalid-name-format' - ) + request = a2a_pb2.TaskPushNotificationConfig(name='invalid-name-format') with pytest.raises(ServerError) as exc_info: proto_utils.FromProto.task_push_notification_config(request) assert isinstance(exc_info.value.error, types.InvalidParamsError) From ff0ad3be3815d96df5d987043b0bd6a818435abc Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 30 Jul 2025 15:44:07 +0000 Subject: [PATCH 11/12] Fix broken tests --- src/a2a/client/transports/rest.py | 16 +--- src/a2a/server/apps/rest/rest_app.py | 31 ++------ .../server/request_handlers/rest_handler.py | 18 ++--- tests/client/test_jsonrpc_client.py | 4 +- tests/client/test_legacy_client.py | 77 +------------------ 5 files changed, 22 insertions(+), 124 deletions(-) diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 7c20fe9c..430d642c 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -310,24 +310,14 @@ async def resubscribe( Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Reconnects to get task updates.""" - pb = a2a_pb2.TaskSubscriptionRequest( - name=f'tasks/{request.id}', - ) - payload = MessageToDict(pb) - payload, modified_kwargs = await self._apply_interceptors( - payload, - self._get_http_args(context), - context, - ) - - modified_kwargs.setdefault('timeout', None) + http_kwargs = self._get_http_args(context) or {} + http_kwargs.setdefault('timeout', None) async with aconnect_sse( self.httpx_client, 'GET', f'{self.url}/v1/tasks/{request.id}:subscribe', - json=payload, - **modified_kwargs, + **http_kwargs, ) as event_source: try: async for sse in event_source.aiter_sse(): diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index 331f19ca..7707a5f7 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -15,10 +15,7 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.rest_handler import RESTHandler -from a2a.types import ( - AgentCard, - AuthenticatedExtendedCardNotConfiguredError, -) +from a2a.types import AgentCard, AuthenticatedExtendedCardNotConfiguredError from a2a.utils.error_handlers import ( rest_error_handler, rest_stream_error_handler, @@ -61,21 +58,7 @@ def __init__( @rest_error_handler async def _handle_request( self, - method: Callable[ - [Request, ServerCallContext], Awaitable[dict[str, Any]] - ], - request: Request, - ) -> Response: - call_context = self._context_builder.build(request) - response = await method(request, call_context) - return JSONResponse(content=response) - - @rest_error_handler - async def _handle_list_request( - self, - method: Callable[ - [Request, ServerCallContext], Awaitable[list[dict[str, Any]]] - ], + method: Callable[[Request, ServerCallContext], Awaitable[Any]], request: Request, ) -> Response: call_context = self._context_builder.build(request) @@ -85,15 +68,13 @@ async def _handle_list_request( @rest_stream_error_handler async def _handle_streaming_request( self, - method: Callable[ - [Request, ServerCallContext], AsyncIterable[dict[str, Any]] - ], + method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], request: Request, ) -> EventSourceResponse: call_context = self._context_builder.build(request) async def event_generator( - stream: AsyncIterable[dict[str, Any]], + stream: AsyncIterable[Any], ) -> AsyncIterator[dict[str, dict[str, Any]]]: async for item in stream: yield {'data': item} @@ -188,10 +169,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: '/v1/tasks/{id}/pushNotificationConfigs', 'GET', ): functools.partial( - self._handle_list_request, self.handler.list_push_notifications + self._handle_request, self.handler.list_push_notifications ), ('/v1/tasks', 'GET'): functools.partial( - self._handle_list_request, self.handler.list_tasks + self._handle_request, self.handler.list_tasks ), } if self.agent_card.supports_authenticated_extended_card: diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index cd64c93b..179ca108 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -3,7 +3,7 @@ from collections.abc import AsyncIterable, AsyncIterator from typing import Any -from google.protobuf.json_format import MessageToDict, Parse +from google.protobuf.json_format import MessageToDict, MessageToJson, Parse from starlette.requests import Request from a2a.grpc import a2a_pb2 @@ -86,7 +86,7 @@ async def on_message_send_stream( self, request: Request, context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: + ) -> AsyncIterator[str]: """Handles the 'message/stream' REST method. Yields response objects as they are produced by the underlying handler's stream. @@ -96,7 +96,7 @@ async def on_message_send_stream( context: Context provided by the server. Yields: - `dict` objects containing streaming events + JSON serialized objects containing streaming events (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON """ body = await request.body() @@ -110,7 +110,7 @@ async def on_message_send_stream( a2a_request, context ): response = proto_utils.ToProto.stream_response(event) - yield MessageToDict(response) + yield MessageToJson(response) async def on_cancel_task( self, @@ -142,7 +142,7 @@ async def on_resubscribe_to_task( self, request: Request, context: ServerCallContext, - ) -> AsyncIterable[dict[str, Any]]: + ) -> AsyncIterable[str]: """Handles the 'tasks/resubscribe' REST method. Yields response objects as they are produced by the underlying handler's stream. @@ -152,13 +152,13 @@ async def on_resubscribe_to_task( context: Context provided by the server. Yields: - `dict` containing streaming events + JSON serialized objects containing streaming events """ task_id = request.path_params['id'] async for event in self.request_handler.on_resubscribe_to_task( TaskIdParams(id=task_id), context ): - yield (MessageToDict(proto_utils.ToProto.stream_response(event))) + yield MessageToJson(proto_utils.ToProto.stream_response(event)) async def get_push_notification( self, @@ -262,7 +262,7 @@ async def list_push_notifications( self, request: Request, context: ServerCallContext, - ) -> list[dict[str, Any]]: + ) -> dict[str, Any]: """Handles the 'tasks/pushNotificationConfig/list' REST method. This method is currently not implemented. @@ -283,7 +283,7 @@ async def list_tasks( self, request: Request, context: ServerCallContext, - ) -> list[dict[str, Any]]: + ) -> dict[str, Any]: """Handles the 'tasks/list' REST method. This method is currently not implemented. diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index a4a9b96e..6f985b77 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -334,7 +334,7 @@ def test_init_with_url(self, mock_httpx_client: AsyncMock): assert client.url == self.AGENT_URL assert client.httpx_client == mock_httpx_client - def test_init_with_agent_card_and_url_prioritizes_agent_card( + def test_init_with_agent_card_and_url_prioritizes_url( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): client = JsonRpcTransport( @@ -342,7 +342,7 @@ def test_init_with_agent_card_and_url_prioritizes_agent_card( agent_card=mock_agent_card, url='http://otherurl.com', ) - assert client.url == mock_agent_card.url + assert client.url == 'http://otherurl.com' def test_init_raises_value_error_if_no_card_or_url( self, mock_httpx_client: AsyncMock diff --git a/tests/client/test_legacy_client.py b/tests/client/test_legacy_client.py index b15872d8..7a7b5811 100644 --- a/tests/client/test_legacy_client.py +++ b/tests/client/test_legacy_client.py @@ -1,83 +1,10 @@ """Tests for the legacy client compatibility layer.""" -from unittest.mock import AsyncMock, MagicMock - -import httpx -import pytest - -from a2a.client import A2AClient, A2AGrpcClient -from a2a.types import ( - AgentCapabilities, - AgentCard, - Message, - MessageSendParams, - Part, - Role, - SendMessageRequest, - Task, - TaskQueryParams, - TaskState, - TaskStatus, - TextPart, -) - - -@pytest.fixture -def mock_httpx_client() -> AsyncMock: - return AsyncMock(spec=httpx.AsyncClient) - - -@pytest.fixture -def mock_grpc_stub() -> AsyncMock: - stub = AsyncMock() - stub._channel = MagicMock() - return stub - - -@pytest.fixture -def jsonrpc_agent_card() -> AgentCard: - return AgentCard( - name='Test Agent', - description='A test agent', - url='http://test.agent.com/rpc', - version='1.0.0', - capabilities=AgentCapabilities(streaming=True), - skills=[], - default_input_modes=[], - default_output_modes=[], - preferred_transport='jsonrpc', - ) - - -@pytest.fixture -def grpc_agent_card() -> AgentCard: - return AgentCard( - name='Test Agent', - description='A test agent', - url='http://test.agent.com/rpc', - version='1.0.0', - capabilities=AgentCapabilities(streaming=True), - skills=[], - default_input_modes=[], - default_output_modes=[], - preferred_transport='grpc', - ) - - -@pytest.mark.asyncio -async def test_a2a_client_send_message( - mock_httpx_client: AsyncMock, jsonrpc_agent_card: AgentCard -): - """Tests for the legacy client compatibility layer.""" - - from unittest.mock import AsyncMock import pytest -from a2a.types import ( - AgentCard, -) +from a2a.types import AgentCard @pytest.fixture @@ -164,7 +91,7 @@ async def test_a2a_grpc_client_get_task( status=TaskStatus(state=TaskState.working), ) - client.get_task.return_value = mock_response_task + client.get_task = AsyncMock(return_value=mock_response_task) params = TaskQueryParams(id='task-456') response = await client.get_task(params) From 22ef4dbebbc530e0811cc88600677e89fc5e4e6b Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 30 Jul 2025 16:00:53 +0000 Subject: [PATCH 12/12] Address issues in legacy client, review comments --- src/a2a/client/legacy.py | 142 ++++++++++++++++++++++++++++- src/a2a/client/transports/base.py | 2 +- tests/client/test_legacy_client.py | 20 +++- 3 files changed, 156 insertions(+), 8 deletions(-) diff --git a/src/a2a/client/legacy.py b/src/a2a/client/legacy.py index 544ccbad..dd289ded 100644 --- a/src/a2a/client/legacy.py +++ b/src/a2a/client/legacy.py @@ -15,6 +15,7 @@ CancelTaskRequest, CancelTaskResponse, CancelTaskSuccessResponse, + GetTaskPushNotificationConfigParams, GetTaskPushNotificationConfigRequest, GetTaskPushNotificationConfigResponse, GetTaskPushNotificationConfigSuccessResponse, @@ -31,6 +32,7 @@ SetTaskPushNotificationConfigRequest, SetTaskPushNotificationConfigResponse, SetTaskPushNotificationConfigSuccessResponse, + TaskIdParams, TaskResubscriptionRequest, ) @@ -62,6 +64,21 @@ async def send_message( http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, ) -> SendMessageResponse: + """Sends a non-streaming message request to the agent. + + Args: + request: The `SendMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ if not context and http_kwargs: context = ClientCallContext(state={'http_kwargs': http_kwargs}) @@ -75,7 +92,7 @@ async def send_message( ) ) except A2AClientJSONRPCError as e: - return SendMessageResponse(root=JSONRPCErrorResponse(error=e.error)) + return SendMessageResponse(JSONRPCErrorResponse(error=e.error)) async def send_message_streaming( self, @@ -84,6 +101,24 @@ async def send_message_streaming( http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, ) -> AsyncGenerator[SendStreamingMessageResponse, None]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `SendStreamingMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ if not context and http_kwargs: context = ClientCallContext(state={'http_kwargs': http_kwargs}) @@ -103,6 +138,21 @@ async def get_task( http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, ) -> GetTaskResponse: + """Retrieves the current state and history of a specific task. + + Args: + request: The `GetTaskRequest` object specifying the task ID and history length. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskResponse` object containing the Task or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ if not context and http_kwargs: context = ClientCallContext(state={'http_kwargs': http_kwargs}) try: @@ -124,6 +174,21 @@ async def cancel_task( http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, ) -> CancelTaskResponse: + """Requests the agent to cancel a specific task. + + Args: + request: The `CancelTaskRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `CancelTaskResponse` object containing the updated Task with canceled status or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ if not context and http_kwargs: context = ClientCallContext(state={'http_kwargs': http_kwargs}) try: @@ -136,7 +201,7 @@ async def cancel_task( ) ) except A2AClientJSONRPCError as e: - return CancelTaskResponse(root=JSONRPCErrorResponse(error=e.error)) + return CancelTaskResponse(JSONRPCErrorResponse(error=e.error)) async def set_task_callback( self, @@ -145,6 +210,21 @@ async def set_task_callback( http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, ) -> SetTaskPushNotificationConfigResponse: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ if not context and http_kwargs: context = ClientCallContext(state={'http_kwargs': http_kwargs}) try: @@ -158,7 +238,7 @@ async def set_task_callback( ) except A2AClientJSONRPCError as e: return SetTaskPushNotificationConfigResponse( - root=JSONRPCErrorResponse(error=e.error) + JSONRPCErrorResponse(error=e.error) ) async def get_task_callback( @@ -168,11 +248,31 @@ async def get_task_callback( http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, ) -> GetTaskPushNotificationConfigResponse: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ if not context and http_kwargs: context = ClientCallContext(state={'http_kwargs': http_kwargs}) + params = request.params + if isinstance(params, TaskIdParams): + params = GetTaskPushNotificationConfigParams( + id=request.params.task_id + ) try: result = await self._transport.get_task_callback( - request.params, context=context + params, context=context ) return GetTaskPushNotificationConfigResponse( root=GetTaskPushNotificationConfigSuccessResponse( @@ -181,7 +281,7 @@ async def get_task_callback( ) except A2AClientJSONRPCError as e: return GetTaskPushNotificationConfigResponse( - root=JSONRPCErrorResponse(error=e.error) + JSONRPCErrorResponse(error=e.error) ) async def resubscribe( @@ -191,6 +291,24 @@ async def resubscribe( http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, ) -> AsyncGenerator[SendStreamingMessageResponse, None]: + """Reconnects to get task updates. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `TaskResubscriptionRequest` object containing the task information to reconnect to. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ if not context and http_kwargs: context = ClientCallContext(state={'http_kwargs': http_kwargs}) @@ -209,6 +327,20 @@ async def get_card( http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, ) -> AgentCard: + """Retrieves the authenticated card (if necessary) or the public one. + + Args: + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `AgentCard` object containing the card or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ if not context and http_kwargs: context = ClientCallContext(state={'http_kwargs': http_kwargs}) return await self._transport.get_card(context=context) diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index ad693f24..3573cb7c 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -96,7 +96,7 @@ async def get_card( *, context: ClientCallContext | None = None, ) -> AgentCard: - """Retrieves the agent's card.""" + """Retrieves the AgentCard.""" @abstractmethod async def close(self) -> None: diff --git a/tests/client/test_legacy_client.py b/tests/client/test_legacy_client.py index 7a7b5811..247f0b18 100644 --- a/tests/client/test_legacy_client.py +++ b/tests/client/test_legacy_client.py @@ -1,10 +1,26 @@ """Tests for the legacy client compatibility layer.""" -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock +import httpx import pytest -from a2a.types import AgentCard +from a2a.client import A2AClient, A2AGrpcClient +from a2a.types import ( + AgentCard, + AgentCapabilities, + Message, + Role, + TextPart, + Part, + Task, + TaskStatus, + TaskState, + TaskQueryParams, + SendMessageRequest, + MessageSendParams, + GetTaskRequest, +) @pytest.fixture