diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 6e88a03d..dae06357 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -11,7 +11,6 @@ from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer from a2a.client.client_factory import ( ClientFactory, - ClientProducer, minimal_agent_card, ) from a2a.client.errors import ( @@ -21,28 +20,14 @@ A2AClientTimeoutError, ) from a2a.client.helpers import create_text_message_object -from a2a.client.jsonrpc_client import ( - A2AClient, - JsonRpcClient, - JsonRpcTransportClient, - NewJsonRpcClient, -) +from a2a.client.legacy import A2AClient from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.client.rest_client import ( - NewRestfulClient, - RestClient, - RestTransportClient, -) logger = logging.getLogger(__name__) try: - from a2a.client.grpc_client import ( - GrpcClient, - GrpcTransportClient, # type: ignore - NewGrpcClient, - ) + from a2a.client.legacy_grpc import A2AGrpcClient except ImportError as e: _original_error = e logger.debug( @@ -50,7 +35,7 @@ _original_error, ) - class GrpcTransportClient: # type: ignore + class A2AGrpcClient: # type: ignore """Placeholder for A2AGrpcClient when dependencies are not installed.""" def __init__(self, *args, **kwargs): @@ -58,20 +43,16 @@ def __init__(self, *args, **kwargs): 'To use A2AGrpcClient, its dependencies must be installed. ' 'You can install them with \'pip install "a2a-sdk[grpc]"\'' ) from _original_error -finally: - # For backward compatability define this alias. This will be deprecated in - # a future release. - A2AGrpcClient = GrpcTransportClient # type: ignore __all__ = [ 'A2ACardResolver', - 'A2AClient', # for backward compatability + 'A2AClient', 'A2AClientError', 'A2AClientHTTPError', 'A2AClientJSONError', 'A2AClientTimeoutError', - 'A2AGrpcClient', # for backward compatability + 'A2AGrpcClient', 'AuthInterceptor', 'Client', 'ClientCallContext', @@ -79,19 +60,9 @@ def __init__(self, *args, **kwargs): 'ClientConfig', 'ClientEvent', 'ClientFactory', - 'ClientProducer', 'Consumer', 'CredentialService', - 'GrpcClient', - 'GrpcTransportClient', 'InMemoryContextCredentialStore', - 'JsonRpcClient', - 'JsonRpcTransportClient', - 'NewGrpcClient', - 'NewJsonRpcClient', - 'NewRestfulClient', - 'RestClient', - 'RestTransportClient', 'create_text_message_object', 'minimal_agent_card', ] diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py new file mode 100644 index 00000000..f4a8d03d --- /dev/null +++ b/src/a2a/client/base_client.py @@ -0,0 +1,241 @@ +from collections.abc import AsyncIterator + +from a2a.client.client import ( + Client, + ClientCallContext, + ClientConfig, + ClientEvent, + Consumer, +) +from a2a.client.client_task_manager import ClientTaskManager +from a2a.client.errors import A2AClientInvalidStateError +from a2a.client.middleware import ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendConfiguration, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) + + +class BaseClient(Client): + """Base implementation of the A2A client, containing transport-independent logic.""" + + def __init__( + self, + card: AgentCard, + config: ClientConfig, + transport: ClientTransport, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], + ): + super().__init__(consumers, middleware) + self._card = card + self._config = config + self._transport = transport + + async def send_message( + self, + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent | Message]: + """Sends a message to the agent. + + This method handles both streaming and non-streaming (polling) interactions + based on the client configuration and agent capabilities. It will yield + events as they are received from the agent. + + Args: + request: The message to send to the agent. + context: The client call context. + + Yields: + An async iterator of `ClientEvent` or a final `Message` response. + """ + config = MessageSendConfiguration( + accepted_output_modes=self._config.accepted_output_modes, + blocking=not self._config.polling, + push_notification_config=( + self._config.push_notification_configs[0] + if self._config.push_notification_configs + else None + ), + ) + params = MessageSendParams(message=request, configuration=config) + + if not self._config.streaming or not self._card.capabilities.streaming: + response = await self._transport.send_message( + params, context=context + ) + result = ( + (response, None) if isinstance(response, Task) else response + ) + await self.consume(result, self._card) + yield result + return + + tracker = ClientTaskManager() + stream = self._transport.send_message_streaming(params, context=context) + + first_event = await anext(stream) + # The response from a server may be either exactly one Message or a + # series of Task updates. Separate out the first message for special + # case handling, which allows us to simplify further stream processing. + if isinstance(first_event, Message): + await self.consume(first_event, self._card) + yield first_event + return + + yield await self._process_response(tracker, first_event) + + async for event in stream: + yield await self._process_response(tracker, event) + + async def _process_response( + self, + tracker: ClientTaskManager, + event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent, + ) -> ClientEvent: + if isinstance(event, Message): + raise A2AClientInvalidStateError( + 'received a streamed Message from server after first response; this is not supported' + ) + await tracker.process(event) + task = tracker.get_task_or_raise() + update = None if isinstance(event, Task) else event + client_event = (task, update) + await self.consume(client_event, self._card) + return client_event + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task. + + Args: + request: The `TaskQueryParams` object specifying the task ID. + context: The client call context. + + Returns: + A `Task` object representing the current state of the task. + """ + return await self._transport.get_task(request, context=context) + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + context: The client call context. + + Returns: + A `Task` object containing the updated task status. + """ + return await self._transport.cancel_task(request, context=context) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `TaskPushNotificationConfig` object with the new configuration. + context: The client call context. + + Returns: + The created or updated `TaskPushNotificationConfig` object. + """ + return await self._transport.set_task_callback(request, context=context) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigParams` object specifying the task. + context: The client call context. + + Returns: + A `TaskPushNotificationConfig` object containing the configuration. + """ + return await self._transport.get_task_callback(request, context=context) + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent]: + """Resubscribes to a task's event stream. + + This is only available if both the client and server support streaming. + + Args: + request: Parameters to identify the task to resubscribe to. + context: The client call context. + + Yields: + An async iterator of `ClientEvent` objects. + + Raises: + NotImplementedError: If streaming is not supported by the client or server. + """ + if not self._config.streaming or not self._card.capabilities.streaming: + raise NotImplementedError( + 'client and/or server do not support resubscription.' + ) + + tracker = ClientTaskManager() + # Note: resubscribe can only be called on an existing task. As such, + # we should never see Message updates, despite the typing of the service + # definition indicating it may be possible. + async for event in self._transport.resubscribe( + request, context=context + ): + yield await self._process_response(tracker, event) + + async def get_card( + self, *, context: ClientCallContext | None = None + ) -> AgentCard: + """Retrieves the agent's card. + + This will fetch the authenticated card if necessary and update the + client's internal state with the new card. + + Args: + context: The client call context. + + Returns: + The `AgentCard` for the agent. + """ + card = await self._transport.get_card(context=context) + self._card = card + return card + + async def close(self) -> None: + """Closes the underlying transport.""" + await self._transport.close() diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index c51597c0..7cc10423 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -119,6 +119,8 @@ async def send_message( pairs, or a `Message`. Client will also send these values to any configured `Consumer`s in the client. """ + return + yield @abstractmethod async def get_task( @@ -164,6 +166,8 @@ async def resubscribe( context: ClientCallContext | None = None, ) -> AsyncIterator[ClientEvent]: """Resubscribes to a task's event stream.""" + return + yield @abstractmethod async def get_card( diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index f47be58a..b2d914ee 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -4,16 +4,14 @@ from collections.abc import Callable -from a2a.client.client import Client, ClientConfig, Consumer - +import httpx -try: - from a2a.client.grpc_client import NewGrpcClient -except ImportError: - NewGrpcClient = None -from a2a.client.jsonrpc_client import NewJsonRpcClient +from a2a.client.base_client import BaseClient +from a2a.client.client import Client, ClientConfig, Consumer from a2a.client.middleware import ClientCallInterceptor -from a2a.client.rest_client import NewRestfulClient +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.client.transports.rest import RestTransport from a2a.types import ( AgentCapabilities, AgentCard, @@ -22,16 +20,20 @@ ) +try: + from a2a.client.transports.grpc import GrpcTransport + from a2a.grpc import a2a_pb2_grpc +except ImportError: + GrpcTransport = None + a2a_pb2_grpc = None + + logger = logging.getLogger(__name__) -ClientProducer = Callable[ - [ - AgentCard, - ClientConfig, - list[Consumer], - list[ClientCallInterceptor], - ], - Client, + +TransportProducer = Callable[ + [AgentCard, str, ClientConfig, list[ClientCallInterceptor]], + ClientTransport, ] @@ -60,23 +62,41 @@ def __init__( consumers = [] self._config = config self._consumers = consumers - self._registry: dict[str, ClientProducer] = {} - # By default register the 3 core transports if in the config. - # Can be overridden with custom clients via the register method. - if TransportProtocol.jsonrpc in self._config.supported_transports: - self._registry[TransportProtocol.jsonrpc] = NewJsonRpcClient - if TransportProtocol.http_json in self._config.supported_transports: - self._registry[TransportProtocol.http_json] = NewRestfulClient - if TransportProtocol.grpc in self._config.supported_transports: - if NewGrpcClient is None: - raise ImportError( - 'To use GrpcClient, its dependencies must be installed. ' - 'You can install them with \'pip install "a2a-sdk[grpc]"\'' - ) - self._registry[TransportProtocol.grpc] = NewGrpcClient - - def register(self, label: str, generator: ClientProducer) -> None: - """Register a new client producer for a given transport label.""" + self._registry: dict[str, TransportProducer] = {} + self._register_defaults() + + def _register_defaults(self) -> None: + self.register( + TransportProtocol.jsonrpc, + lambda card, url, config, interceptors: JsonRpcTransport( + config.httpx_client or httpx.AsyncClient(), + card, + url, + interceptors, + ), + ) + self.register( + TransportProtocol.http_json, + lambda card, url, config, interceptors: RestTransport( + config.httpx_client or httpx.AsyncClient(), + card, + url, + interceptors, + ), + ) + if GrpcTransport: + self.register( + TransportProtocol.grpc, + lambda card, url, config, interceptors: GrpcTransport( + a2a_pb2_grpc.A2AServiceStub( + config.grpc_channel_factory(url) + ), + card, + ), + ) + + def register(self, label: str, generator: TransportProducer) -> None: + """Register a new transport producer for a given transport label.""" self._registry[label] = generator def create( @@ -101,34 +121,44 @@ def create( If there is no valid matching of the client configuration with the server configuration, a `ValueError` is raised. """ - # Determine preferential transport - server_set = [card.preferred_transport or TransportProtocol.jsonrpc] + server_preferred = card.preferred_transport or TransportProtocol.jsonrpc + server_set = {server_preferred: card.url} if card.additional_interfaces: - server_set.extend([x.transport for x in card.additional_interfaces]) + server_set.update( + {x.transport: x.url for x in card.additional_interfaces} + ) client_set = self._config.supported_transports or [ TransportProtocol.jsonrpc ] - transport = None - # Two options, use the client ordering or the server ordering. + transport_protocol = None + transport_url = None if self._config.use_client_preference: for x in client_set: if x in server_set: - transport = x + transport_protocol = x + transport_url = server_set[x] break else: - for x in server_set: + for x, url in server_set.items(): if x in client_set: - transport = x + transport_protocol = x + transport_url = url break - if not transport: + if not transport_protocol or not transport_url: raise ValueError('no compatible transports found.') - if transport not in self._registry: - raise ValueError(f'no client available for {transport}') + if transport_protocol not in self._registry: + raise ValueError(f'no client available for {transport_protocol}') + all_consumers = self._consumers.copy() if consumers: all_consumers.extend(consumers) - return self._registry[transport]( - card, self._config, all_consumers, interceptors or [] + + transport = self._registry[transport_protocol]( + card, transport_url, self._config, interceptors or [] + ) + + return BaseClient( + card, self._config, transport, all_consumers, interceptors or [] ) diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py deleted file mode 100644 index 1ead88c8..00000000 --- a/src/a2a/client/grpc_client.py +++ /dev/null @@ -1,544 +0,0 @@ -import logging - -from collections.abc import AsyncGenerator, AsyncIterator - - -try: - import grpc -except ImportError as e: - raise ImportError( - 'A2AGrpcClient requires grpcio and grpcio-tools to be installed. ' - 'Install with: ' - "'pip install a2a-sdk[grpc]'" - ) from e - - -from a2a.client.client import ( - Client, - ClientCallContext, - ClientConfig, - ClientEvent, - Consumer, -) -from a2a.client.client_task_manager import ClientTaskManager -from a2a.client.errors import A2AClientInvalidStateError -from a2a.client.middleware import ClientCallInterceptor -from a2a.grpc import a2a_pb2, a2a_pb2_grpc -from a2a.types import ( - AgentCard, - GetTaskPushNotificationConfigParams, - Message, - MessageSendConfiguration, - MessageSendParams, - Task, - TaskArtifactUpdateEvent, - TaskIdParams, - TaskPushNotificationConfig, - TaskQueryParams, - TaskStatusUpdateEvent, -) -from a2a.utils import proto_utils -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.CLIENT) -class GrpcTransportClient: - """Transport specific details for interacting with an A2A agent via gRPC.""" - - def __init__( - self, - grpc_stub: a2a_pb2_grpc.A2AServiceStub, - agent_card: AgentCard | None, - ): - """Initializes the GrpcTransportClient. - - Requires an `AgentCard` and a grpc `A2AServiceStub`. - - Args: - grpc_stub: A grpc client stub. - agent_card: The agent card object. - """ - self.agent_card = agent_card - self.stub = grpc_stub - # If they don't provide an agent card, but do have a stub, lookup the - # card from the stub. - self._needs_extended_card = ( - agent_card.supports_authenticated_extended_card - if agent_card - else True - ) - - async def send_message( - self, - request: MessageSendParams, - *, - context: ClientCallContext | None = None, - ) -> Task | Message: - """Sends a non-streaming message request to the agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - context: The client call context. - - Returns: - A `Task` or `Message` object containing the agent's response. - """ - response = await self.stub.SendMessage( - a2a_pb2.SendMessageRequest( - request=proto_utils.ToProto.message(request.message), - configuration=proto_utils.ToProto.message_send_configuration( - request.configuration - ), - metadata=proto_utils.ToProto.metadata(request.metadata), - ) - ) - if response.task: - return proto_utils.FromProto.task(response.task) - return proto_utils.FromProto.message(response.msg) - - async def send_message_streaming( - self, - request: MessageSendParams, - *, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[ - Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent - ]: - """Sends a streaming message request to the agent and yields responses as they arrive. - - This method uses gRPC streams to receive a stream of updates from the - agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - context: The client call context. - - Yields: - `Message` or `Task` or `TaskStatusUpdateEvent` or - `TaskArtifactUpdateEvent` objects as they are received in the - stream. - """ - stream = self.stub.SendStreamingMessage( - a2a_pb2.SendMessageRequest( - request=proto_utils.ToProto.message(request.message), - configuration=proto_utils.ToProto.message_send_configuration( - request.configuration - ), - metadata=proto_utils.ToProto.metadata(request.metadata), - ) - ) - while True: - response = await stream.read() - if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue] - break - yield proto_utils.FromProto.stream_response(response) - - async def resubscribe( - self, request: TaskIdParams, *, context: ClientCallContext | None = None - ) -> AsyncGenerator[ - Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent - ]: - """Reconnects to get task updates. - - This method uses a unary server-side stream to receive updates. - - Args: - request: The `TaskIdParams` object containing the task information to reconnect to. - context: The client call context. - - Yields: - Task update events, which can be either a Task, Message, - TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientInvalidStateError: If the server returns an invalid response. - """ - stream = self.stub.TaskSubscription( - a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}') - ) - while True: - response = await stream.read() - if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue] - break - yield proto_utils.FromProto.stream_response(response) - - async def get_task( - self, - request: TaskQueryParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieves the current state and history of a specific task. - - Args: - request: The `TaskQueryParams` object specifying the task ID - context: The client call context. - - Returns: - A `Task` object containing the Task. - """ - task = await self.stub.GetTask( - a2a_pb2.GetTaskRequest(name=f'tasks/{request.id}') - ) - return proto_utils.FromProto.task(task) - - async def cancel_task( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Requests the agent to cancel a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - context: The client call context. - - Returns: - A `Task` object containing the updated Task - """ - task = await self.stub.CancelTask( - a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') - ) - return proto_utils.FromProto.task(task) - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the config. - """ - config = await self.stub.CreateTaskPushNotificationConfig( - a2a_pb2.CreateTaskPushNotificationConfigRequest( - parent='', - config_id='', - config=proto_utils.ToProto.task_push_notification_config( - request - ), - ) - ) - return proto_utils.FromProto.task_push_notification_config_request( - config - ) - - async def get_task_callback( - self, - request: TaskIdParams, # TODO: Update to a push id params - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieves the push notification configuration for a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the configuration. - """ - config = await self.stub.GetTaskPushNotificationConfig( - a2a_pb2.GetTaskPushNotificationConfigRequest( - name=f'tasks/{request.id}/pushNotification/undefined', - ) - ) - return proto_utils.FromProto.task_push_notification_config_request( - config - ) - - async def get_card( - self, - *, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieves the authenticated card (if necessary) or the public one. - - Args: - context: The client call context. - - Returns: - A `AgentCard` object containing the card. - - Raises: - grpc.RpcError: If a gRPC error occurs during the request. - """ - # If we don't have the public card, try to get that first. - card = self.agent_card - if card is None and not self._needs_extended_card: - raise ValueError('Agent card is not available.') - - if not self._needs_extended_card: - return card - - card_pb = await self.stub.GetAgentCard( - a2a_pb2.GetAgentCardRequest(), - ) - card = proto_utils.FromProto.agent_card(card_pb) - self.agent_card = card - self._needs_extended_card = False - return card - - -class GrpcClient(Client): - """GrpcClient provides the Client interface for the gRPC transport.""" - - def __init__( - self, - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], - ): - super().__init__(consumers, middleware) - if not config.grpc_channel_factory: - raise ValueError('GRPC client requires channel factory.') - self._card = card - self._config = config - channel = config.grpc_channel_factory(self._card.url) - stub = a2a_pb2_grpc.A2AServiceStub(channel) - self._transport_client = GrpcTransportClient(stub, self._card) - - async def send_message( - self, - request: Message, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent | Message]: - """Sends a message to the agent. - - This method handles both streaming and non-streaming (polling) interactions - based on the client configuration and agent capabilities. It will yield - events as they are received from the agent. - - Args: - request: The message to send to the agent. - context: The client call context. - - Yields: - An async iterator of `ClientEvent` or a final `Message` response. - """ - config = MessageSendConfiguration( - accepted_output_modes=self._config.accepted_output_modes, - blocking=not self._config.polling, - push_notification_config=( - self._config.push_notification_configs[0] - if self._config.push_notification_configs - else None - ), - ) - if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport_client.send_message( - MessageSendParams( - message=request, - configuration=config, - ), - context=context, - ) - result = ( - (response, None) if isinstance(response, Task) else response - ) - await self.consume(result, self._card) - yield result - return - tracker = ClientTaskManager() - stream = self._transport_client.send_message_streaming( - MessageSendParams( - message=request, - configuration=config, - ), - context=context, - ) - # Only the first event may be a Message. All others must be Task - # or TaskStatusUpdates. Separate this one out, which allows our core - # event processing logic to ignore that case. - # TODO(mikeas1): Reconcile with other transport logic. - first_event = await anext(stream) - if isinstance(first_event, Message): - yield first_event - return - yield await self._process_response(tracker, first_event) - async for result in stream: - yield await self._process_response(tracker, result) - - async def _process_response( - self, - tracker: ClientTaskManager, - event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent, - ) -> ClientEvent: - result = event.root.result - # Update task, check for errors, etc. - if isinstance(result, Message): - raise A2AClientInvalidStateError( - 'received a streamed Message from server after first response; this' - ' is not supported' - ) - await tracker.process(result) - result = ( - tracker.get_task_or_raise(), - None if isinstance(result, Task) else result, - ) - await self.consume(result, self._card) - return result - - async def get_task( - self, - request: TaskQueryParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieves the current state and history of a specific task. - - Args: - request: The `TaskQueryParams` object specifying the task ID. - context: The client call context. - - Returns: - A `Task` object representing the current state of the task. - """ - return await self._transport_client.get_task( - request, - context=context, - ) - - async def cancel_task( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Requests the agent to cancel a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - context: The client call context. - - Returns: - A `Task` object containing the updated task status. - """ - return await self._transport_client.cancel_task( - request, - context=context, - ) - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `TaskPushNotificationConfig` object with the new configuration. - context: The client call context. - - Returns: - The created or updated `TaskPushNotificationConfig` object. - """ - return await self._transport_client.set_task_callback( - request, - context=context, - ) - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigParams, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieves the push notification configuration for a specific task. - - Args: - request: The `GetTaskPushNotificationConfigParams` object specifying the task. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the configuration. - """ - return await self._transport_client.get_task_callback( - request, - context=context, - ) - - async def resubscribe( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: - """Resubscribes to a task's event stream. - - This is only available if both the client and server support streaming. - - Args: - request: The `TaskIdParams` object specifying the task ID to resubscribe to. - context: The client call context. - - Yields: - An async iterator of `Task` or `Message` events. - - Raises: - Exception: If streaming is not supported by the client or server. - """ - if not self._config.streaming or not self._card.capabilities.streaming: - raise NotImplementedError( - 'client and/or server do not support resubscription.' - ) - if not self._transport_client: - raise ValueError('Transport client is not initialized.') - if not hasattr(self._transport_client, 'resubscribe'): - # This can happen if the proto definitions are out of date or the method is missing - raise NotImplementedError( - 'Resubscribe is not implemented on the gRPC transport client.' - ) - # Note: works correctly for resubscription where the first event is the - # current Task state. - tracker = ClientTaskManager() - async for result in self._transport_client.resubscribe( - request, - context=context, - ): - yield await self._process_response(tracker, result) - - async def get_card( - self, - *, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieves the agent's card. - - This will fetch the authenticated card if necessary and update the - client's internal state with the new card. - - Args: - context: The client call context. - - Returns: - The `AgentCard` for the agent. - """ - card = await self._transport_client.get_card( - context=context, - ) - self._card = card - return card - - -def NewGrpcClient( # noqa: N802 - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], -) -> Client: - """Generator for the `GrpcClient` implementation.""" - return GrpcClient(card, config, consumers, middleware) diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py deleted file mode 100644 index 5bf23dc2..00000000 --- a/src/a2a/client/jsonrpc_client.py +++ /dev/null @@ -1,851 +0,0 @@ -import json -import logging - -from collections.abc import AsyncGenerator, AsyncIterator -from typing import Any -from uuid import uuid4 - -import httpx - -from httpx_sse import SSEError, aconnect_sse - -from a2a.client.card_resolver import A2ACardResolver -from a2a.client.client import ( - Client, - ClientConfig, - ClientEvent, - Consumer, -) -from a2a.client.client_task_manager import ClientTaskManager -from a2a.client.errors import ( - A2AClientHTTPError, - A2AClientInvalidStateError, - A2AClientJSONError, - A2AClientJSONRPCError, - A2AClientTimeoutError, -) -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.types import ( - AgentCard, - CancelTaskRequest, - CancelTaskResponse, - GetTaskPushNotificationConfigParams, - GetTaskPushNotificationConfigRequest, - GetTaskPushNotificationConfigResponse, - GetTaskRequest, - GetTaskResponse, - JSONRPCErrorResponse, - Message, - MessageSendConfiguration, - MessageSendParams, - SendMessageRequest, - SendMessageResponse, - SendStreamingMessageRequest, - SendStreamingMessageResponse, - SetTaskPushNotificationConfigRequest, - SetTaskPushNotificationConfigResponse, - Task, - TaskIdParams, - TaskPushNotificationConfig, - TaskQueryParams, - TaskResubscriptionRequest, -) -from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.CLIENT) -class JsonRpcTransportClient: - """A2A Client for interacting with an A2A agent.""" - - def __init__( - self, - httpx_client: httpx.AsyncClient, - agent_card: AgentCard | None = None, - url: str | None = None, - interceptors: list[ClientCallInterceptor] | None = None, - ): - """Initializes the A2AClient. - - Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. - - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. - url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. - interceptors: An optional list of client call interceptors to apply to requests. - - Raises: - ValueError: If neither `agent_card` nor `url` is provided. - """ - if agent_card: - self.url = agent_card.url - elif url: - self.url = url - else: - raise ValueError('Must provide either agent_card or url') - - self.httpx_client = httpx_client - self.agent_card = agent_card - self.interceptors = interceptors or [] - # Indicate if we have captured an extended card details so we can update - # on first call if needed. It is done this way so the caller can setup - # their auth credentials based on the public card and get the updated - # card. - self._needs_extended_card = ( - not agent_card.supports_authenticated_extended_card - if agent_card - else True - ) - - async def _apply_interceptors( - self, - method_name: str, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Applies all registered interceptors to the request.""" - final_http_kwargs = http_kwargs or {} - final_request_payload = request_payload - - for interceptor in self.interceptors: - ( - final_request_payload, - final_http_kwargs, - ) = await interceptor.intercept( - method_name, - final_request_payload, - final_http_kwargs, - self.agent_card, - context, - ) - return final_request_payload, final_http_kwargs - - @staticmethod - async def get_client_from_agent_card_url( - httpx_client: httpx.AsyncClient, - base_url: str, - agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, - http_kwargs: dict[str, Any] | None = None, - ) -> 'JsonRpcTransportClient': - """[deprecated] Fetches the public AgentCard and initializes an A2A client. - - This method will always fetch the public agent card. If an authenticated - or extended agent card is required, the A2ACardResolver should be used - directly to fetch the specific card, and then the A2AClient should be - instantiated with it. - - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - base_url: The base URL of the agent's host. - agent_card_path: The path to the agent card endpoint, relative to the base URL. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.get request when fetching the agent card. - - Returns: - An initialized `A2AClient` instance. - - Raises: - A2AClientHTTPError: If an HTTP error occurs fetching the agent card. - A2AClientJSONError: If the agent card response is invalid. - """ - agent_card: AgentCard = await A2ACardResolver( - httpx_client, base_url=base_url, agent_card_path=agent_card_path - ).get_agent_card( - http_kwargs=http_kwargs - ) # Fetches public card by default - return JsonRpcTransportClient( - httpx_client=httpx_client, agent_card=agent_card - ) - - async def send_message( - self, - request: SendMessageRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> SendMessageResponse: - """Sends a non-streaming message request to the agent. - - Args: - request: The `SendMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/send', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SendMessageResponse.model_validate(response_data) - - async def send_message_streaming( - self, - request: SendStreamingMessageRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[SendStreamingMessageResponse]: - """Sends a streaming message request to the agent and yields responses as they arrive. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `SendStreamingMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - `SendStreamingMessageResponse` objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/stream', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'POST', - self.url, - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - yield SendStreamingMessageResponse.model_validate( - json.loads(sse.data) - ) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def _send_request( - self, - rpc_request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Sends a non-streaming JSON-RPC request to the agent. - - Args: - rpc_request_payload: JSON RPC payload for sending the request. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - - Returns: - The JSON response payload as a dictionary. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON. - """ - try: - response = await self.httpx_client.post( - self.url, json=rpc_request_payload, **(http_kwargs or {}) - ) - response.raise_for_status() - return response.json() - except httpx.ReadTimeout as e: - raise A2AClientTimeoutError('Client Request timed out') from e - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def get_task( - self, - request: GetTaskRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> GetTaskResponse: - """Retrieves the current state and history of a specific task. - - Args: - request: The `GetTaskRequest` object specifying the task ID and history length. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `GetTaskResponse` object containing the Task or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskResponse.model_validate(response_data) - - async def cancel_task( - self, - request: CancelTaskRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> CancelTaskResponse: - """Requests the agent to cancel a specific task. - - Args: - request: The `CancelTaskRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `CancelTaskResponse` object containing the updated Task with canceled status or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/cancel', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return CancelTaskResponse.model_validate(response_data) - - async def set_task_callback( - self, - request: SetTaskPushNotificationConfigRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> SetTaskPushNotificationConfigResponse: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/set', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SetTaskPushNotificationConfigResponse.model_validate( - response_data - ) - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> GetTaskPushNotificationConfigResponse: - """Retrieves the push notification configuration for a specific task. - - Args: - request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskPushNotificationConfigResponse.model_validate( - response_data - ) - - async def resubscribe( - self, - request: TaskResubscriptionRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[SendStreamingMessageResponse]: - """Reconnects to get task updates. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `TaskResubscriptionRequest` object containing the task information to reconnect to. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - `SendStreamingMessageResponse` objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/resubscribe', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'POST', - self.url, - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - yield SendStreamingMessageResponse.model_validate_json( - sse.data - ) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def get_card( - self, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieves the authenticated card (if necessary) or the public one. - - Args: - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `AgentCard` object containing the card or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - # If we don't have the public card, try to get that first. - card = self.agent_card - if not card: - resolver = A2ACardResolver(self.httpx_client, self.url) - card = await resolver.get_agent_card(http_kwargs=http_kwargs) - self._needs_extended_card = ( - card.supports_authenticated_extended_card - ) - self.agent_card = card - - if not self._needs_extended_card: - return card - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'card/getAuthenticated', - {}, - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - card = AgentCard.model_validate(response_data) - self.agent_card = card - self._needs_extended_card = False - return card - - -@trace_class(kind=SpanKind.CLIENT) -class JsonRpcClient(Client): - """JsonRpcClient is the implementation of the JSONRPC A2A client. - - This client proxies requests to the JsonRpcTransportClient implementation - and manages the JSONRPC specific details. If passing additional arguments - in the http.post command, these should be attached to the ClientCallContext - under the dictionary key 'http_kwargs'. - """ - - def __init__( - self, - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], - ): - super().__init__(consumers, middleware) - if not config.httpx_client: - raise Exception('JsonRpc client requires httpx client.') - self._card = card - url = card.url - self._config = config - self._transport_client = JsonRpcTransportClient( - config.httpx_client, self._card, url, middleware - ) - - def get_http_args( - self, context: ClientCallContext | None - ) -> dict[str, Any] | None: - """Extract HTTP-specific keyword arguments from the client call context. - - Args: - context: The client call context. - - Returns: - A dictionary of HTTP arguments, or None. - """ - return context.state.get('http_kwargs', None) if context else None - - async def send_message( - self, - request: Message, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent | Message]: - """Send a message to the agent and consumes the response(s). - - This method handles both blocking (non-streaming) and streaming responses - based on the client configuration and agent capabilities. - - Args: - request: The message to send. - context: The client call context. - - Yields: - An async iterator of `ClientEvent` or a final `Message` response. - - Raises: - JSONRPCError: If the agent returns a JSON-RPC error in the response. - """ - config = MessageSendConfiguration( - accepted_output_modes=self._config.accepted_output_modes, - blocking=not self._config.polling, - push_notification_config=( - self._config.push_notification_configs[0] - if self._config.push_notification_configs - else None - ), - ) - if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport_client.send_message( - SendMessageRequest( - params=MessageSendParams( - message=request, - configuration=config, - ), - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - if isinstance(response.root, JSONRPCErrorResponse): - raise A2AClientJSONRPCError(response.root) - result = response.root.result - result = result if isinstance(result, Message) else (result, None) - await self.consume(result, self._card) - yield result - return - tracker = ClientTaskManager() - stream = self._transport_client.send_message_streaming( - SendStreamingMessageRequest( - params=MessageSendParams( - message=request, - configuration=config, - ), - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - # Only the first event may be a Message. All others must be Task - # or TaskStatusUpdates. Separate this one out, which allows our core - # event processing logic to ignore that case. - first_event = await anext(stream) - if isinstance(first_event, Message): - yield first_event - return - yield await self._process_response(tracker, first_event) - async for event in stream: - yield await self._process_response(tracker, event) - - async def _process_response( - self, - tracker: ClientTaskManager, - event: SendStreamingMessageResponse, - ) -> ClientEvent: - if isinstance(event.root, JSONRPCErrorResponse): - raise A2AClientJSONRPCError(event.root) - result = event.root.result - # Update task, check for errors, etc. - if isinstance(result, Message): - raise A2AClientInvalidStateError( - 'received a streamed Message from server after first response; this' - ' is not supported' - ) - await tracker.process(result) - result = ( - tracker.get_task_or_raise(), - None if isinstance(result, Task) else result, - ) - await self.consume(result, self._card) - return result - - async def get_task( - self, - request: TaskQueryParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieve a task from the agent. - - Args: - request: Parameters to identify the task. - context: The client call context. - - Returns: - The requested task. - """ - response = await self._transport_client.get_task( - GetTaskRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - return response.root.result - - async def cancel_task( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Cancel an ongoing task on the agent. - - Args: - request: Parameters to identify the task to cancel. - context: The client call context. - - Returns: - The task after the cancellation request. - """ - response = await self._transport_client.cancel_task( - CancelTaskRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - return response.root.result - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Set a push notification callback for a task. - - Args: - request: The push notification configuration to set. - context: The client call context. - - Returns: - The configured task push notification configuration. - """ - response = await self._transport_client.set_task_callback( - SetTaskPushNotificationConfigRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - return response.root.result - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigParams, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieve the push notification callback configuration for a task. - - Args: - request: Parameters to identify the task and configuration. - context: The client call context. - - Returns: - The requested task push notification configuration. - """ - response = await self._transport_client.get_task_callback( - GetTaskPushNotificationConfigRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - return response.root.result - - async def resubscribe( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: - """Resubscribe to a task's event stream. - - This is only available if both the client and server support streaming. - - Args: - request: Parameters to identify the task to resubscribe to. - context: The client call context. - - Yields: - Task events from the agent. - - Raises: - Exception: If streaming is not supported. - """ - if not self._config.streaming or not self._card.capabilities.streaming: - raise NotImplementedError( - 'client and/or server do not support resubscription.' - ) - tracker = ClientTaskManager() - async for event in self._transport_client.resubscribe( - TaskResubscriptionRequest( - params=request, - id=str(uuid4()), - ), - http_kwargs=self.get_http_args(context), - context=context, - ): - yield await self._process_response(tracker, event) - - async def get_card( - self, - *, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieve the agent's card. - - This may involve fetching the public card first if not already available, - and then fetching the authenticated extended card if supported and required. - - Args: - context: The client call context. - - Returns: - The agent's card. - """ - return await self._transport_client.get_card( - http_kwargs=self.get_http_args(context), - context=context, - ) - - -def NewJsonRpcClient( # noqa: N802 - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], -) -> Client: - """Factory function for the `JsonRpcClient` implementation.""" - return JsonRpcClient(card, config, consumers, middleware) - - -# For backward compatability define this alias. This will be deprecated in -# a future release. -A2AClient = JsonRpcTransportClient diff --git a/src/a2a/client/legacy.py b/src/a2a/client/legacy.py new file mode 100644 index 00000000..dd289ded --- /dev/null +++ b/src/a2a/client/legacy.py @@ -0,0 +1,346 @@ +"""Backwards compatibility layer for legacy A2A clients.""" + +import warnings + +from collections.abc import AsyncGenerator +from typing import Any + +import httpx + +from a2a.client.errors import A2AClientJSONRPCError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.types import ( + AgentCard, + CancelTaskRequest, + CancelTaskResponse, + CancelTaskSuccessResponse, + GetTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigSuccessResponse, + GetTaskRequest, + GetTaskResponse, + GetTaskSuccessResponse, + JSONRPCErrorResponse, + SendMessageRequest, + SendMessageResponse, + SendMessageSuccessResponse, + SendStreamingMessageRequest, + SendStreamingMessageResponse, + SendStreamingMessageSuccessResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + SetTaskPushNotificationConfigSuccessResponse, + TaskIdParams, + TaskResubscriptionRequest, +) + + +class A2AClient: + """[DEPRECATED] Backwards compatibility wrapper for the JSON-RPC client.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + warnings.warn( + 'A2AClient is deprecated and will be removed in a future version. ' + 'Use ClientFactory to create a client with a JSON-RPC transport.', + DeprecationWarning, + stacklevel=2, + ) + self._transport = JsonRpcTransport( + httpx_client, agent_card, url, interceptors + ) + + async def send_message( + self, + request: SendMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SendMessageResponse: + """Sends a non-streaming message request to the agent. + + Args: + request: The `SendMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + + try: + result = await self._transport.send_message( + request.params, context=context + ) + return SendMessageResponse( + root=SendMessageSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return SendMessageResponse(JSONRPCErrorResponse(error=e.error)) + + async def send_message_streaming( + self, + request: SendStreamingMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse, None]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `SendStreamingMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + + async for result in self._transport.send_message_streaming( + request.params, context=context + ): + yield SendStreamingMessageResponse( + root=SendStreamingMessageSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + + async def get_task( + self, + request: GetTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskResponse: + """Retrieves the current state and history of a specific task. + + Args: + request: The `GetTaskRequest` object specifying the task ID and history length. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskResponse` object containing the Task or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.get_task( + request.params, context=context + ) + return GetTaskResponse( + root=GetTaskSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return GetTaskResponse(root=JSONRPCErrorResponse(error=e.error)) + + async def cancel_task( + self, + request: CancelTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> CancelTaskResponse: + """Requests the agent to cancel a specific task. + + Args: + request: The `CancelTaskRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `CancelTaskResponse` object containing the updated Task with canceled status or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.cancel_task( + request.params, context=context + ) + return CancelTaskResponse( + root=CancelTaskSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return CancelTaskResponse(JSONRPCErrorResponse(error=e.error)) + + async def set_task_callback( + self, + request: SetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SetTaskPushNotificationConfigResponse: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + try: + result = await self._transport.set_task_callback( + request.params, context=context + ) + return SetTaskPushNotificationConfigResponse( + root=SetTaskPushNotificationConfigSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return SetTaskPushNotificationConfigResponse( + JSONRPCErrorResponse(error=e.error) + ) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskPushNotificationConfigResponse: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + params = request.params + if isinstance(params, TaskIdParams): + params = GetTaskPushNotificationConfigParams( + id=request.params.task_id + ) + try: + result = await self._transport.get_task_callback( + params, context=context + ) + return GetTaskPushNotificationConfigResponse( + root=GetTaskPushNotificationConfigSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + except A2AClientJSONRPCError as e: + return GetTaskPushNotificationConfigResponse( + JSONRPCErrorResponse(error=e.error) + ) + + async def resubscribe( + self, + request: TaskResubscriptionRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse, None]: + """Reconnects to get task updates. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `TaskResubscriptionRequest` object containing the task information to reconnect to. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + + async for result in self._transport.resubscribe( + request.params, context=context + ): + yield SendStreamingMessageResponse( + root=SendStreamingMessageSuccessResponse( + id=request.id, jsonrpc='2.0', result=result + ) + ) + + async def get_card( + self, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the authenticated card (if necessary) or the public one. + + Args: + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `AgentCard` object containing the card or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not context and http_kwargs: + context = ClientCallContext(state={'http_kwargs': http_kwargs}) + return await self._transport.get_card(context=context) diff --git a/src/a2a/client/legacy_grpc.py b/src/a2a/client/legacy_grpc.py new file mode 100644 index 00000000..9a4b1656 --- /dev/null +++ b/src/a2a/client/legacy_grpc.py @@ -0,0 +1,24 @@ +"""Backwards compatibility layer for the legacy A2A gRPC client.""" + +import warnings + +from a2a.client.transports.grpc import GrpcTransport +from a2a.grpc import a2a_pb2_grpc +from a2a.types import AgentCard + + +class A2AGrpcClient(GrpcTransport): + """[DEPRECATED] Backwards compatibility wrapper for the gRPC client.""" + + def __init__( + self, + grpc_stub: 'a2a_pb2_grpc.A2AServiceStub', + agent_card: AgentCard, + ): + warnings.warn( + 'A2AGrpcClient is deprecated and will be removed in a future version. ' + 'Use ClientFactory to create a client with a gRPC transport.', + DeprecationWarning, + stacklevel=2, + ) + super().__init__(grpc_stub, agent_card) diff --git a/src/a2a/client/rest_client.py b/src/a2a/client/rest_client.py deleted file mode 100644 index 552defc3..00000000 --- a/src/a2a/client/rest_client.py +++ /dev/null @@ -1,833 +0,0 @@ -import json -import logging - -from collections.abc import AsyncGenerator, AsyncIterator -from typing import Any - -import httpx - -from google.protobuf.json_format import MessageToDict, Parse, ParseDict -from httpx_sse import SSEError, aconnect_sse - -from a2a.client.card_resolver import A2ACardResolver -from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer -from a2a.client.client_task_manager import ClientTaskManager -from a2a.client.errors import ( - A2AClientHTTPError, - A2AClientInvalidStateError, - A2AClientJSONError, -) -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.grpc import a2a_pb2 -from a2a.types import ( - AgentCard, - GetTaskPushNotificationConfigParams, - Message, - MessageSendConfiguration, - MessageSendParams, - Task, - TaskArtifactUpdateEvent, - TaskIdParams, - TaskPushNotificationConfig, - TaskQueryParams, - TaskStatusUpdateEvent, -) -from a2a.utils import proto_utils -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.CLIENT) -class RestTransportClient: - """A2A Client for interacting with an A2A agent.""" - - def __init__( - self, - httpx_client: httpx.AsyncClient, - agent_card: AgentCard | None = None, - url: str | None = None, - interceptors: list[ClientCallInterceptor] | None = None, - ): - """Initializes the A2AClient. - - Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. - - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. - url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. - interceptors: An optional list of client call interceptors to apply to requests. - - Raises: - ValueError: If neither `agent_card` nor `url` is provided. - """ - if agent_card: - self.url = agent_card.url - elif url: - self.url = url - else: - raise ValueError('Must provide either agent_card or url') - # If the url ends in / remove it as this is added by the routes - if self.url.endswith('/'): - self.url = self.url[:-1] - self.httpx_client = httpx_client - self.agent_card = agent_card - self.interceptors = interceptors or [] - # Indicate if we have captured an extended card details so we can update - # on first call if needed. It is done this way so the caller can setup - # their auth credentials based on the public card and get the updated - # card. - self._needs_extended_card = ( - not agent_card.supports_authenticated_extended_card - if agent_card - else True - ) - - async def _apply_interceptors( - self, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Applies all registered interceptors to the request.""" - final_http_kwargs = http_kwargs or {} - final_request_payload = request_payload - # TODO: Implement interceptors for other transports - return final_request_payload, final_http_kwargs - - async def send_message( - self, - request: MessageSendParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> Task | Message: - """Sends a non-streaming message request to the agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `Task` or `Message` object containing the agent's response. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.SendMessageRequest( - request=proto_utils.ToProto.message(request.message), - configuration=proto_utils.ToProto.message_send_configuration( - request.configuration - ), - metadata=( - proto_utils.ToProto.metadata(request.metadata) - if request.metadata - else None - ), - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - response_data = await self._send_post_request( - '/v1/message:send', payload, modified_kwargs - ) - response_pb = a2a_pb2.SendMessageResponse() - ParseDict(response_data, response_pb) - return proto_utils.FromProto.task_or_message(response_pb) - - async def send_message_streaming( - self, - request: MessageSendParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[ - Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message - ]: - """Sends a streaming message request to the agent and yields responses as they arrive. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `MessageSendParams` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - Objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.SendMessageRequest( - request=proto_utils.ToProto.message(request.message), - configuration=proto_utils.ToProto.message_send_configuration( - request.configuration - ), - metadata=( - proto_utils.ToProto.metadata(request.metadata) - if request.metadata - else None - ), - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'POST', - f'{self.url}/v1/message:stream', - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - event = a2a_pb2.StreamResponse() - Parse(sse.data, event) - yield proto_utils.FromProto.stream_response(event) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def _send_post_request( - self, - target: str, - rpc_request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Sends a non-streaming JSON-RPC request to the agent. - - Args: - target: url path - rpc_request_payload: JSON payload for sending the request. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - - Returns: - The JSON response payload as a dictionary. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON. - """ - try: - response = await self.httpx_client.post( - f'{self.url}{target}', - json=rpc_request_payload, - **(http_kwargs or {}), - ) - response.raise_for_status() - return response.json() - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def _send_get_request( - self, - target: str, - query_params: dict[str, str], - http_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Sends a non-streaming JSON-RPC request to the agent. - - Args: - target: url path - query_params: HTTP query params for the request. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - - Returns: - The JSON response payload as a dictionary. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON. - """ - try: - response = await self.httpx_client.get( - f'{self.url}{target}', - params=query_params, - **(http_kwargs or {}), - ) - response.raise_for_status() - return response.json() - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def get_task( - self, - request: TaskQueryParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieves the current state and history of a specific task. - - Args: - request: The `TaskQueryParams` object specifying the task ID and history length. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `Task` object containing the Task. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - # Apply interceptors before sending - only for the http kwargs - payload, modified_kwargs = await self._apply_interceptors( - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_get_request( - f'/v1/tasks/{request.taskId}', - {'historyLength': str(request.history_length)} - if request.history_length - else {}, - modified_kwargs, - ) - task = a2a_pb2.Task() - ParseDict(response_data, task) - return proto_utils.FromProto.task(task) - - async def cancel_task( - self, - request: TaskIdParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> Task: - """Requests the agent to cancel a specific task. - - Args: - request: The `TaskIdParams` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `Task` object containing the updated Task with canceled status - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - response_data = await self._send_post_request( - f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs - ) - task = a2a_pb2.Task() - ParseDict(response_data, task) - return proto_utils.FromProto.task(task) - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the confirmation. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( - parent=f'tasks/{request.taskId}', - config_id=request.push_notification_config.id, - config=proto_utils.ToProto.task_push_notification_config(request), - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, http_kwargs, context - ) - response_data = await self._send_post_request( - f'/v1/tasks/{request.taskId}/pushNotificationConfigs/', - payload, - modified_kwargs, - ) - config = a2a_pb2.TaskPushNotificationConfig() - ParseDict(response_data, config) - return proto_utils.FromProto.task_push_notification_config(config) - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieves the push notification configuration for a specific task. - - Args: - request: The `GetTaskPushNotificationConfigParams` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `TaskPushNotificationConfig` object containing the configuration. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.GetTaskPushNotificationConfigRequest( - name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - response_data = await self._send_get_request( - f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', - {}, - modified_kwargs, - ) - config = a2a_pb2.TaskPushNotificationConfig() - ParseDict(response_data, config) - return proto_utils.FromProto.task_push_notification_config(config) - - async def resubscribe( - self, - request: TaskIdParams, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AsyncGenerator[ - Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message - ]: - """Reconnects to get task updates. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `TaskIdParams` object containing the task information to reconnect to. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - Objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - pb = a2a_pb2.TaskSubscriptionRequest( - name=f'tasks/{request.id}', - ) - payload = MessageToDict(pb) - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'GET', - f'{self.url}/v1/tasks/{request.id}:subscribe', - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - event = a2a_pb2.StreamResponse() - Parse(sse.data, event) - yield proto_utils.FromProto.stream_response(event) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def get_card( - self, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieves the authenticated card (if necessary) or the public one. - - Args: - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `AgentCard` object containing the card or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - # If we don't have the public card, try to get that first. - card = self.agent_card - if not card: - resolver = A2ACardResolver(self.httpx_client, self.url) - card = await resolver.get_agent_card(http_kwargs=http_kwargs) - self._needs_extended_card = ( - card.supports_authenticated_extended_card - ) - self.agent_card = card - - if not self._needs_extended_card: - return card - - # Apply interceptors before sending - _, modified_kwargs = await self._apply_interceptors( - {}, - http_kwargs, - context, - ) - response_data = await self._send_get_request( - '/v1/card/get', {}, modified_kwargs - ) - card = AgentCard.model_validate(response_data) - self.agent_card = card - self._needs_extended_card = False - return card - - -@trace_class(kind=SpanKind.CLIENT) -class RestClient(Client): - """RestClient is the implementation of the RESTful A2A client. - - This client proxies requests to the RestTransportClient implementation - and manages the REST specific details. If passing additional arguments - in the http.post command, these should be attached to the ClientCallContext - under the dictionary key 'http_kwargs'. - """ - - def __init__( - self, - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], - ): - super().__init__(consumers, middleware) - if not config.httpx_client: - raise ValueError('RestClient client requires httpx client.') - self._card = card - url = card.url - self._config = config - self._transport_client = RestTransportClient( - config.httpx_client, self._card, url, middleware - ) - - def get_http_args( - self, context: ClientCallContext | None - ) -> dict[str, Any] | None: - """Extract HTTP-specific keyword arguments from the client call context. - - Args: - context: The client call context. - - Returns: - A dictionary of HTTP arguments, or None. - """ - return context.state.get('http_kwargs', None) if context else None - - async def send_message( - self, - request: Message, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[Message | ClientEvent]: - """Send a message to the agent and consumes the response(s). - - This method handles both blocking (non-streaming) and streaming responses - based on the client configuration and agent capabilities. - - Args: - request: The message to send. - context: The client call context. - - Yields: - The final message or task result from the agent. - """ - config = MessageSendConfiguration( - accepted_output_modes=self._config.accepted_output_modes, - blocking=not self._config.polling, - push_notification_config=( - self._config.push_notification_configs[0] - if self._config.push_notification_configs - else None - ), - ) - if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport_client.send_message( - MessageSendParams( - message=request, - configuration=config, - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - result = ( - response if isinstance(response, Message) else (response, None) - ) - await self.consume(result, self._card) - yield result - return - tracker = ClientTaskManager() - stream = self._transport_client.send_message_streaming( - MessageSendParams( - message=request, - configuration=config, - ), - http_kwargs=self.get_http_args(context), - context=context, - ) - # Only the first event may be a Message. All others must be Task - # or TaskStatusUpdates. Separate this one out, which allows our core - # event processing logic to ignore that case. - first_event = await anext(stream) - if isinstance(first_event, Message): - yield first_event - return - yield await self._process_response(tracker, first_event) - async for event in stream: - yield await self._process_response(tracker, event) - - async def _process_response( - self, - tracker: ClientTaskManager, - event: Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message, - ) -> ClientEvent: - result = event.root.result - # Update task, check for errors, etc. - if isinstance(result, Message): - raise A2AClientInvalidStateError( - 'received a streamed Message from server after first response; this' - ' is not supported' - ) - await tracker.process(result) - result = ( - tracker.get_task_or_raise(), - None if isinstance(result, Task) else result, - ) - await self.consume(result, self._card) - return result - - async def get_task( - self, - request: TaskQueryParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Retrieve a task from the agent. - - Args: - request: Parameters to identify the task. - context: The client call context. - - Returns: - The requested task. - """ - return await self._transport_client.get_task( - request, - http_kwargs=self.get_http_args(context), - context=context, - ) - - async def cancel_task( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> Task: - """Cancel an ongoing task on the agent. - - Args: - request: Parameters to identify the task to cancel. - context: The client call context. - - Returns: - The task after the cancellation request. - """ - return await self._transport_client.cancel_task( - request, - http_kwargs=self.get_http_args(context), - context=context, - ) - - async def set_task_callback( - self, - request: TaskPushNotificationConfig, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Set a push notification callback for a task. - - Args: - request: The push notification configuration to set. - context: The client call context. - - Returns: - The configured task push notification configuration. - """ - return await self._transport_client.set_task_callback( - request, - http_kwargs=self.get_http_args(context), - context=context, - ) - - async def get_task_callback( - self, - request: GetTaskPushNotificationConfigParams, - *, - context: ClientCallContext | None = None, - ) -> TaskPushNotificationConfig: - """Retrieve the push notification callback configuration for a task. - - Args: - request: Parameters to identify the task and configuration. - context: The client call context. - - Returns: - The requested task push notification configuration. - """ - return await self._transport_client.get_task_callback( - request, - http_kwargs=self.get_http_args(context), - context=context, - ) - - async def resubscribe( - self, - request: TaskIdParams, - *, - context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: - """Resubscribe to a task's event stream. - - This is only available if both the client and server support streaming. - - Args: - request: Parameters to identify the task to resubscribe to. - context: The client call context. - - Yields: - Task events from the agent. - - Raises: - Exception: If streaming is not supported. - """ - if not self._config.streaming or not self._card.capabilities.streaming: - raise NotImplementedError( - 'client and/or server do not support resubscription.' - ) - tracker = ClientTaskManager() - async for event in self._transport_client.resubscribe( - request, - http_kwargs=self.get_http_args(context), - context=context, - ): - # Update task, check for errors, etc. - yield await self._process_response(tracker, event) - - async def get_card( - self, - *, - context: ClientCallContext | None = None, - ) -> AgentCard: - """Retrieve the agent's card. - - This may involve fetching the public card first if not already available, - and then fetching the authenticated extended card if supported and required. - - Args: - context: The client call context. - - Returns: - The agent's card. - """ - return await self._transport_client.get_card( - http_kwargs=self.get_http_args(context), - context=context, - ) - - -def NewRestfulClient( # noqa: N802 - card: AgentCard, - config: ClientConfig, - consumers: list[Consumer], - middleware: list[ClientCallInterceptor], -) -> Client: - """Factory function for the `RestClient` implementation.""" - return RestClient(card, config, consumers, middleware) diff --git a/src/a2a/client/transports/__init__.py b/src/a2a/client/transports/__init__.py new file mode 100644 index 00000000..0e600ff4 --- /dev/null +++ b/src/a2a/client/transports/__init__.py @@ -0,0 +1,19 @@ +"""A2A Client Transports.""" + +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.client.transports.rest import RestTransport + + +try: + from a2a.client.transports.grpc import GrpcTransport +except ImportError: + GrpcTransport = None + + +__all__ = [ + 'ClientTransport', + 'GrpcTransport', + 'JsonRpcTransport', + 'RestTransport', +] diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py new file mode 100644 index 00000000..3573cb7c --- /dev/null +++ b/src/a2a/client/transports/base.py @@ -0,0 +1,103 @@ +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator + +from a2a.client.middleware import ClientCallContext +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) + + +class ClientTransport(ABC): + """Abstract base class for a client transport.""" + + @abstractmethod + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + + @abstractmethod + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + return + yield + + @abstractmethod + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + + @abstractmethod + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + + @abstractmethod + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + + @abstractmethod + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + + @abstractmethod + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Reconnects to get task updates.""" + return + yield + + @abstractmethod + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the AgentCard.""" + + @abstractmethod + async def close(self) -> None: + """Closes the transport.""" diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py new file mode 100644 index 00000000..c340ed63 --- /dev/null +++ b/src/a2a/client/transports/grpc.py @@ -0,0 +1,193 @@ +import logging + +from collections.abc import AsyncGenerator + + +try: + import grpc +except ImportError as e: + raise ImportError( + 'A2AGrpcClient requires grpcio and grpcio-tools to be installed. ' + 'Install with: ' + "'pip install a2a-sdk[grpc]'" + ) from e + +from a2a.client.middleware import ClientCallContext +from a2a.client.transports.base import ClientTransport +from a2a.grpc import a2a_pb2, a2a_pb2_grpc +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) +from a2a.utils import proto_utils +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class GrpcTransport(ClientTransport): + """A gRPC transport for the A2A client.""" + + def __init__( + self, + grpc_stub: a2a_pb2_grpc.A2AServiceStub, + agent_card: AgentCard | None, + ): + """Initializes the GrpcTransport.""" + self.agent_card = agent_card + self.stub = grpc_stub + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card + if agent_card + else True + ) + + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + response = await self.stub.SendMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + if response.task: + return proto_utils.FromProto.task(response.task) + return proto_utils.FromProto.message(response.msg) + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + stream = self.stub.SendStreamingMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: + break + yield proto_utils.FromProto.stream_response(response) + + async def resubscribe( + self, request: TaskIdParams, *, context: ClientCallContext | None = None + ) -> AsyncGenerator[ + Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Reconnects to get task updates.""" + stream = self.stub.TaskSubscription( + a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}') + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: + break + yield proto_utils.FromProto.stream_response(response) + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + task = await self.stub.GetTask( + a2a_pb2.GetTaskRequest(name=f'tasks/{request.id}') + ) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + task = await self.stub.CancelTask( + a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + ) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + config = await self.stub.CreateTaskPushNotificationConfig( + a2a_pb2.CreateTaskPushNotificationConfigRequest( + parent=f'tasks/{request.task_id}', + config_id=request.push_notification_config.id, + config=proto_utils.ToProto.task_push_notification_config( + request + ), + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + config = await self.stub.GetTaskPushNotificationConfig( + a2a_pb2.GetTaskPushNotificationConfigRequest( + name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + card = self.agent_card + if card and not self._needs_extended_card: + return card + if card is None and not self._needs_extended_card: + raise ValueError('Agent card is not available.') + + card_pb = await self.stub.GetAgentCard( + a2a_pb2.GetAgentCardRequest(), + ) + card = proto_utils.FromProto.agent_card(card_pb) + self.agent_card = card + self._needs_extended_card = False + return card + + async def close(self) -> None: + """Closes the gRPC channel.""" + if hasattr(self.stub, 'close'): + await self.stub.close() diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py new file mode 100644 index 00000000..868b3a01 --- /dev/null +++ b/src/a2a/client/transports/jsonrpc.py @@ -0,0 +1,376 @@ +import json +import logging + +from collections.abc import AsyncGenerator +from typing import Any +from uuid import uuid4 + +import httpx + +from httpx_sse import SSEError, aconnect_sse + +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, + A2AClientJSONRPCError, + A2AClientTimeoutError, +) +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.types import ( + AgentCard, + CancelTaskRequest, + CancelTaskResponse, + GetAuthenticatedExtendedCardRequest, + GetAuthenticatedExtendedCardResponse, + GetTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskRequest, + GetTaskResponse, + JSONRPCErrorResponse, + Message, + MessageSendParams, + SendMessageRequest, + SendMessageResponse, + SendStreamingMessageRequest, + SendStreamingMessageResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskResubscriptionRequest, + TaskStatusUpdateEvent, +) +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class JsonRpcTransport(ClientTransport): + """A JSON-RPC transport for the A2A client.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the JsonRpcTransport.""" + if url: + self.url = url + elif agent_card: + self.url = agent_card.url + else: + raise ValueError('Must provide either agent_card or url') + + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card + if agent_card + else True + ) + + async def _apply_interceptors( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + + for interceptor in self.interceptors: + ( + final_request_payload, + final_http_kwargs, + ) = await interceptor.intercept( + method_name, + final_request_payload, + final_http_kwargs, + self.agent_card, + context, + ) + return final_request_payload, final_http_kwargs + + def _get_http_args( + self, context: ClientCallContext | None + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs') if context else None + + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + rpc_request = SendMessageRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + 'message/send', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = SendMessageResponse.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + rpc_request = SendStreamingMessageRequest( + params=request, id=str(uuid4()) + ) + payload, modified_kwargs = await self._apply_interceptors( + 'message/stream', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + response = SendStreamingMessageResponse.model_validate( + json.loads(sse.data) + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + yield response.root.result + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_request( + self, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + try: + response = await self.httpx_client.post( + self.url, json=rpc_request_payload, **(http_kwargs or {}) + ) + response.raise_for_status() + return response.json() + except httpx.ReadTimeout as e: + raise A2AClientTimeoutError('Client Request timed out') from e + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + rpc_request = GetTaskRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/get', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = GetTaskResponse.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + rpc_request = CancelTaskRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/cancel', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = CancelTaskResponse.model_validate(response_data) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + rpc_request = SetTaskPushNotificationConfigRequest( + params=request, id=str(uuid4()) + ) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/set', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = SetTaskPushNotificationConfigResponse.model_validate( + response_data + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + rpc_request = GetTaskPushNotificationConfigRequest( + params=request, id=str(uuid4()) + ) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/get', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + response = GetTaskPushNotificationConfigResponse.model_validate( + response_data + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + return response.root.result + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Reconnects to get task updates.""" + rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/resubscribe', + rpc_request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + response = SendStreamingMessageResponse.model_validate_json( + sse.data + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + yield response.root.result + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + card = self.agent_card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = await resolver.get_agent_card( + http_kwargs=self._get_http_args(context) + ) + self._needs_extended_card = ( + card.supports_authenticated_extended_card + ) + self.agent_card = card + + if not self._needs_extended_card: + return card + + request = GetAuthenticatedExtendedCardRequest(id=str(uuid4())) + payload, modified_kwargs = await self._apply_interceptors( + request.method, + request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + + response_data = await self._send_request( + payload, + modified_kwargs, + ) + response = GetAuthenticatedExtendedCardResponse.model_validate( + response_data + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(response.root) + self.agent_card = response.root.result + self._needs_extended_card = False + return card + + async def close(self) -> None: + """Closes the httpx client.""" + await self.httpx_client.aclose() diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py new file mode 100644 index 00000000..430d642c --- /dev/null +++ b/src/a2a/client/transports/rest.py @@ -0,0 +1,373 @@ +import json +import logging + +from collections.abc import AsyncGenerator +from typing import Any + +import httpx + +from google.protobuf.json_format import MessageToDict, Parse, ParseDict +from httpx_sse import SSEError, aconnect_sse + +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports.base import ClientTransport +from a2a.grpc import a2a_pb2 +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) +from a2a.utils import proto_utils +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class RestTransport(ClientTransport): + """A REST transport for the A2A client.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the RestTransport.""" + if url: + self.url = url + elif agent_card: + self.url = agent_card.url + else: + raise ValueError('Must provide either agent_card or url') + if self.url.endswith('/'): + self.url = self.url[:-1] + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + self._needs_extended_card = ( + agent_card.supports_authenticated_extended_card + if agent_card + else True + ) + + async def _apply_interceptors( + self, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + # TODO: Implement interceptors for other transports + return final_request_payload, final_http_kwargs + + def _get_http_args( + self, context: ClientCallContext | None + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs') if context else None + + async def send_message( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent.""" + pb = a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=( + proto_utils.ToProto.metadata(request.metadata) + if request.metadata + else None + ), + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + response_data = await self._send_post_request( + '/v1/message:send', payload, modified_kwargs + ) + response_pb = a2a_pb2.SendMessageResponse() + ParseDict(response_data, response_pb) + return proto_utils.FromProto.task_or_message(response_pb) + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + pb = a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=( + proto_utils.ToProto.metadata(request.metadata) + if request.metadata + else None + ), + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + f'{self.url}/v1/message:stream', + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_post_request( + self, + target: str, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + try: + response = await self.httpx_client.post( + f'{self.url}{target}', + json=rpc_request_payload, + **(http_kwargs or {}), + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_get_request( + self, + target: str, + query_params: dict[str, str], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + try: + response = await self.httpx_client.get( + f'{self.url}{target}', + params=query_params, + **(http_kwargs or {}), + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + payload, modified_kwargs = await self._apply_interceptors( + request.model_dump(mode='json', exclude_none=True), + self._get_http_args(context), + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.id}', + {'historyLength': str(request.history_length)} + if request.history_length + else {}, + modified_kwargs, + ) + task = a2a_pb2.Task() + ParseDict(response_data, task) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs + ) + task = a2a_pb2.Task() + ParseDict(response_data, task) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( + parent=f'tasks/{request.task_id}', + config_id=request.push_notification_config.id, + config=proto_utils.ToProto.task_push_notification_config(request), + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, self._get_http_args(context), context + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.task_id}/pushNotificationConfigs', + payload, + modified_kwargs, + ) + config = a2a_pb2.TaskPushNotificationConfig() + ParseDict(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + pb = a2a_pb2.GetTaskPushNotificationConfigRequest( + name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + ) + payload = MessageToDict(pb) + payload, modified_kwargs = await self._apply_interceptors( + payload, + self._get_http_args(context), + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + {}, + modified_kwargs, + ) + config = a2a_pb2.TaskPushNotificationConfig() + ParseDict(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: + """Reconnects to get task updates.""" + http_kwargs = self._get_http_args(context) or {} + http_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'GET', + f'{self.url}/v1/tasks/{request.id}:subscribe', + **http_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, f'Invalid SSE response or protocol error: {e}' + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the agent's card.""" + card = self.agent_card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = await resolver.get_agent_card( + http_kwargs=self._get_http_args(context) + ) + self._needs_extended_card = ( + card.supports_authenticated_extended_card + ) + self.agent_card = card + + if not self._needs_extended_card: + return card + + _, modified_kwargs = await self._apply_interceptors( + {}, + self._get_http_args(context), + context, + ) + response_data = await self._send_get_request( + '/v1/card', {}, modified_kwargs + ) + card = AgentCard.model_validate(response_data) + self.agent_card = card + self._needs_extended_card = False + return card + + async def close(self) -> None: + """Closes the httpx client.""" + await self.httpx_client.aclose() diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index 8d2c1625..7707a5f7 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -14,13 +14,8 @@ ) from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.rest_handler import ( - RESTHandler, -) -from a2a.types import ( - AgentCard, - AuthenticatedExtendedCardNotConfiguredError, -) +from a2a.server.request_handlers.rest_handler import RESTHandler +from a2a.types import AgentCard, AuthenticatedExtendedCardNotConfiguredError from a2a.utils.error_handlers import ( rest_error_handler, rest_stream_error_handler, @@ -63,21 +58,7 @@ def __init__( @rest_error_handler async def _handle_request( self, - method: Callable[ - [Request, ServerCallContext], Awaitable[dict[str, Any]] - ], - request: Request, - ) -> Response: - call_context = self._context_builder.build(request) - response = await method(request, call_context) - return JSONResponse(content=response) - - @rest_error_handler - async def _handle_list_request( - self, - method: Callable[ - [Request, ServerCallContext], Awaitable[list[dict[str, Any]]] - ], + method: Callable[[Request, ServerCallContext], Awaitable[Any]], request: Request, ) -> Response: call_context = self._context_builder.build(request) @@ -87,15 +68,13 @@ async def _handle_list_request( @rest_stream_error_handler async def _handle_streaming_request( self, - method: Callable[ - [Request, ServerCallContext], AsyncIterable[dict[str, Any]] - ], + method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], request: Request, ) -> EventSourceResponse: call_context = self._context_builder.build(request) async def event_generator( - stream: AsyncIterable[dict[str, Any]], + stream: AsyncIterable[Any], ) -> AsyncIterator[dict[str, dict[str, Any]]]: async for item in stream: yield {'data': item} @@ -164,7 +143,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: self._handle_streaming_request, self.handler.on_message_send_stream, ), - ('/v1/tasks/{id}:subscribe', 'POST'): functools.partial( + ('/v1/tasks/{id}:cancel', 'POST'): functools.partial( + self._handle_request, self.handler.on_cancel_task + ), + ('/v1/tasks/{id}:subscribe', 'GET'): functools.partial( self._handle_streaming_request, self.handler.on_resubscribe_to_task, ), @@ -187,10 +169,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: '/v1/tasks/{id}/pushNotificationConfigs', 'GET', ): functools.partial( - self._handle_list_request, self.handler.list_push_notifications + self._handle_request, self.handler.list_push_notifications ), ('/v1/tasks', 'GET'): functools.partial( - self._handle_list_request, self.handler.list_tasks + self._handle_request, self.handler.list_tasks ), } if self.agent_card.supports_authenticated_extended_card: diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 717217a7..179ca108 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -3,7 +3,7 @@ from collections.abc import AsyncIterable, AsyncIterator from typing import Any -from google.protobuf.json_format import MessageToDict, Parse +from google.protobuf.json_format import MessageToDict, MessageToJson, Parse from starlette.requests import Request from a2a.grpc import a2a_pb2 @@ -86,7 +86,7 @@ async def on_message_send_stream( self, request: Request, context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: + ) -> AsyncIterator[str]: """Handles the 'message/stream' REST method. Yields response objects as they are produced by the underlying handler's stream. @@ -96,7 +96,7 @@ async def on_message_send_stream( context: Context provided by the server. Yields: - `dict` objects containing streaming events + JSON serialized objects containing streaming events (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON """ body = await request.body() @@ -110,7 +110,7 @@ async def on_message_send_stream( a2a_request, context ): response = proto_utils.ToProto.stream_response(event) - yield MessageToDict(response) + yield MessageToJson(response) async def on_cancel_task( self, @@ -142,7 +142,7 @@ async def on_resubscribe_to_task( self, request: Request, context: ServerCallContext, - ) -> AsyncIterable[dict[str, Any]]: + ) -> AsyncIterable[str]: """Handles the 'tasks/resubscribe' REST method. Yields response objects as they are produced by the underlying handler's stream. @@ -152,13 +152,13 @@ async def on_resubscribe_to_task( context: Context provided by the server. Yields: - `dict` containing streaming events + JSON serialized objects containing streaming events """ task_id = request.path_params['id'] async for event in self.request_handler.on_resubscribe_to_task( TaskIdParams(id=task_id), context ): - yield (MessageToDict(proto_utils.ToProto.stream_response(event))) + yield MessageToJson(proto_utils.ToProto.stream_response(event)) async def get_push_notification( self, @@ -216,7 +216,7 @@ async def set_push_notification( (due to the `@validate` decorator), A2AError if processing error is found. """ - _ = request.path_params['id'] + task_id = request.path_params['id'] body = await request.body() params = a2a_pb2.CreateTaskPushNotificationConfigRequest() Parse(body, params) @@ -225,6 +225,7 @@ async def set_push_notification( params, ) ) + a2a_request.task_id = task_id config = ( await self.request_handler.on_set_task_push_notification_config( a2a_request, context @@ -261,7 +262,7 @@ async def list_push_notifications( self, request: Request, context: ServerCallContext, - ) -> list[dict[str, Any]]: + ) -> dict[str, Any]: """Handles the 'tasks/pushNotificationConfig/list' REST method. This method is currently not implemented. @@ -282,7 +283,7 @@ async def list_tasks( self, request: Request, context: ServerCallContext, - ) -> list[dict[str, Any]]: + ) -> dict[str, Any]: """Handles the 'tasks/list' REST method. This method is currently not implemented. diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 091268ba..0760690b 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -1,6 +1,7 @@ """General utility functions for the A2A Python SDK.""" import functools +import inspect import logging from collections.abc import Callable @@ -135,7 +136,22 @@ def validate( """ def decorator(function: Callable) -> Callable: - def wrapper(self: Any, *args, **kwargs) -> Any: + if inspect.iscoroutinefunction(function): + + @functools.wraps(function) + async def async_wrapper(self: Any, *args, **kwargs) -> Any: + if not expression(self): + final_message = error_message or str(expression) + logger.error(f'Unsupported Operation: {final_message}') + raise ServerError( + UnsupportedOperationError(message=final_message) + ) + return await function(self, *args, **kwargs) + + return async_wrapper + + @functools.wraps(function) + def sync_wrapper(self: Any, *args, **kwargs) -> Any: if not expression(self): final_message = error_message or str(expression) logger.error(f'Unsupported Operation: {final_message}') @@ -144,7 +160,7 @@ def wrapper(self: Any, *args, **kwargs) -> Any: ) return function(self, *args, **kwargs) - return wrapper + return sync_wrapper return decorator diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 4d87280f..541504d5 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -2,6 +2,7 @@ """Utils for converting between proto and Python types.""" import json +import logging import re from typing import Any @@ -13,9 +14,14 @@ from a2a.utils.errors import ServerError +logger = logging.getLogger(__name__) + + # Regexp patterns for matching -_TASK_NAME_MATCH = r'tasks/(\w+)' -_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotificationConfigs/(\w+)' +_TASK_NAME_MATCH = r'tasks/([\w-]+)' +_TASK_PUSH_CONFIG_NAME_MATCH = ( + r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)' +) class ToProto: diff --git a/tests/client/test_auth_middleware.py b/tests/client/test_auth_middleware.py index 55fb5b8b..4f53ca3f 100644 --- a/tests/client/test_auth_middleware.py +++ b/tests/client/test_auth_middleware.py @@ -1,3 +1,5 @@ +import json + from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -6,9 +8,15 @@ import pytest import respx -from a2a.client import A2AClient -from a2a.client.auth import AuthInterceptor, InMemoryContextCredentialStore -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client import ( + AuthInterceptor, + Client, + ClientCallContext, + ClientCallInterceptor, + ClientConfig, + ClientFactory, + InMemoryContextCredentialStore, +) from a2a.types import ( APIKeySecurityScheme, AgentCapabilities, @@ -17,14 +25,13 @@ HTTPAuthSecurityScheme, In, Message, - MessageSendParams, OAuth2SecurityScheme, OAuthFlows, OpenIdConnectSecurityScheme, Role, SecurityScheme, - SendMessageRequest, SendMessageSuccessResponse, + TransportProtocol, ) @@ -49,10 +56,11 @@ async def intercept( return request_payload, http_kwargs -def build_success_response() -> dict: - """Creates a valid JSON-RPC success response as dict.""" - return SendMessageSuccessResponse( - id='1', +def build_success_response(request: httpx.Request) -> httpx.Response: + """Creates a valid JSON-RPC success response based on the request.""" + request_payload = json.loads(request.content) + response_payload = SendMessageSuccessResponse( + id=request_payload['id'], jsonrpc='2.0', result=Message( kind='message', @@ -61,41 +69,33 @@ def build_success_response() -> dict: parts=[], ), ).model_dump(mode='json') + return httpx.Response(200, json=response_payload) -def build_send_message_request() -> SendMessageRequest: - """Builds a minimal SendMessageRequest.""" - return SendMessageRequest( - id='1', - params=MessageSendParams( - message=Message( - message_id='msg1', - role=Role.user, - parts=[], - ) - ), +def build_message() -> Message: + """Builds a minimal Message.""" + return Message( + message_id='msg1', + role=Role.user, + parts=[], ) async def send_message( - client: A2AClient, + client: Client, url: str, session_id: str | None = None, ) -> httpx.Request: """Mocks the response and sends a message using the client.""" - respx.post(url).mock( - return_value=httpx.Response( - 200, - json=build_success_response(), - ) - ) + respx.post(url).mock(side_effect=build_success_response) context = ClientCallContext( state={'sessionId': session_id} if session_id else {} ) - await client.send_message( - request=build_send_message_request(), + async for _ in client.send_message( + request=build_message(), context=context, - ) + ): + pass return respx.calls.last.request @@ -170,11 +170,26 @@ async def test_client_with_simple_interceptor(): """ url = 'http://agent.com/rpc' interceptor = HeaderInterceptor('X-Test-Header', 'Test-Value-123') + card = AgentCard( + url=url, + name='testbot', + description='test bot', + version='1.0', + default_input_modes=[], + default_output_modes=[], + skills=[], + capabilities=AgentCapabilities(), + preferred_transport=TransportProtocol.jsonrpc, + ) async with httpx.AsyncClient() as http_client: - client = A2AClient( - httpx_client=http_client, url=url, interceptors=[interceptor] + config = ClientConfig( + httpx_client=http_client, + supported_transports=[TransportProtocol.jsonrpc], ) + factory = ClientFactory(config) + client = factory.create(card, interceptors=[interceptor]) + request = await send_message(client, url) assert request.headers['x-test-header'] == 'Test-Value-123' @@ -293,14 +308,17 @@ async def test_auth_interceptor_variants(test_case, store): root=test_case.security_scheme ) }, + preferred_transport=TransportProtocol.jsonrpc, ) async with httpx.AsyncClient() as http_client: - client = A2AClient( + config = ClientConfig( httpx_client=http_client, - agent_card=agent_card, - interceptors=[auth_interceptor], + supported_transports=[TransportProtocol.jsonrpc], ) + factory = ClientFactory(config) + client = factory.create(agent_card, interceptors=[auth_interceptor]) + request = await send_message( client, test_case.url, test_case.session_id ) diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py new file mode 100644 index 00000000..d615bbff --- /dev/null +++ b/tests/client/test_client_factory.py @@ -0,0 +1,105 @@ +"""Tests for the ClientFactory.""" + +import httpx +import pytest + +from a2a.client import ClientConfig, ClientFactory +from a2a.client.transports import JsonRpcTransport, RestTransport +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentInterface, + TransportProtocol, +) + + +@pytest.fixture +def base_agent_card() -> AgentCard: + """Provides a base AgentCard for tests.""" + return AgentCard( + name='Test Agent', + description='An agent for testing.', + url='http://primary-url.com', + version='1.0.0', + capabilities=AgentCapabilities(), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport=TransportProtocol.jsonrpc, + ) + + +def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard): + """Verify that the factory selects the preferred transport by default.""" + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[ + TransportProtocol.jsonrpc, + TransportProtocol.http_json, + ], + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, JsonRpcTransport) + assert client._transport.url == 'http://primary-url.com' + + +def test_client_factory_selects_secondary_transport_url( + base_agent_card: AgentCard, +): + """Verify that the factory selects the correct URL for a secondary transport.""" + base_agent_card.additional_interfaces = [ + AgentInterface( + transport=TransportProtocol.http_json, + url='http://secondary-url.com', + ) + ] + # Client prefers REST, which is available as a secondary transport + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[ + TransportProtocol.http_json, + TransportProtocol.jsonrpc, + ], + use_client_preference=True, + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, RestTransport) + assert client._transport.url == 'http://secondary-url.com' + + +def test_client_factory_server_preference(base_agent_card: AgentCard): + """Verify that the factory respects server transport preference.""" + base_agent_card.preferred_transport = TransportProtocol.http_json + base_agent_card.additional_interfaces = [ + AgentInterface( + transport=TransportProtocol.jsonrpc, url='http://secondary-url.com' + ) + ] + # Client supports both, but server prefers REST + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[ + TransportProtocol.jsonrpc, + TransportProtocol.http_json, + ], + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, RestTransport) + assert client._transport.url == 'http://primary-url.com' + + +def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): + """Verify that the factory raises an error if no compatible transport is found.""" + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_transports=[TransportProtocol.grpc], + ) + factory = ClientFactory(config) + with pytest.raises(ValueError, match='no compatible transports found'): + factory.create(base_agent_card) diff --git a/tests/client/test_grpc_client.py b/tests/client/test_grpc_client.py index 26967d73..7a9d6830 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/test_grpc_client.py @@ -2,7 +2,7 @@ import pytest -from a2a.client import A2AGrpcClient +from a2a.client.transports.grpc import GrpcTransport from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCapabilities, @@ -51,11 +51,11 @@ def sample_agent_card() -> AgentCard: @pytest.fixture -def grpc_client( +def grpc_transport( mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard -) -> A2AGrpcClient: - """Provides an A2AGrpcClient instance.""" - return A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=sample_agent_card) +) -> GrpcTransport: + """Provides a GrpcTransport instance.""" + return GrpcTransport(grpc_stub=mock_grpc_stub, agent_card=sample_agent_card) @pytest.fixture @@ -92,7 +92,7 @@ def sample_message() -> Message: @pytest.mark.asyncio async def test_send_message_task_response( - grpc_client: A2AGrpcClient, + grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_message_send_params: MessageSendParams, sample_task: Task, @@ -102,7 +102,7 @@ async def test_send_message_task_response( task=proto_utils.ToProto.task(sample_task) ) - response = await grpc_client.send_message(sample_message_send_params) + response = await grpc_transport.send_message(sample_message_send_params) mock_grpc_stub.SendMessage.assert_awaited_once() assert isinstance(response, Task) @@ -111,13 +111,13 @@ async def test_send_message_task_response( @pytest.mark.asyncio async def test_get_task( - grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task + grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task ): """Test retrieving a task.""" mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(sample_task) params = TaskQueryParams(id=sample_task.id) - response = await grpc_client.get_task(params) + response = await grpc_transport.get_task(params) mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest(name=f'tasks/{sample_task.id}') @@ -127,7 +127,7 @@ async def test_get_task( @pytest.mark.asyncio async def test_cancel_task( - grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task + grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task ): """Test cancelling a task.""" cancelled_task = sample_task.model_copy() @@ -137,7 +137,7 @@ async def test_cancel_task( ) params = TaskIdParams(id=sample_task.id) - response = await grpc_client.cancel_task(params) + response = await grpc_transport.cancel_task(params) mock_grpc_stub.CancelTask.assert_awaited_once_with( a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index 61a637f6..6f985b77 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -2,7 +2,7 @@ from collections.abc import AsyncGenerator from typing import Any -from unittest.mock import ANY, AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -11,40 +11,24 @@ from a2a.client import ( A2ACardResolver, - A2AClient, A2AClientHTTPError, A2AClientJSONError, A2AClientTimeoutError, create_text_message_object, ) +from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.types import ( - A2ARequest, AgentCapabilities, AgentCard, AgentSkill, - CancelTaskRequest, - CancelTaskResponse, - CancelTaskSuccessResponse, - GetTaskPushNotificationConfigRequest, - GetTaskPushNotificationConfigResponse, - GetTaskPushNotificationConfigSuccessResponse, - GetTaskRequest, - GetTaskResponse, InvalidParamsError, - JSONRPCErrorResponse, + Message, MessageSendParams, PushNotificationConfig, Role, - SendMessageRequest, - SendMessageResponse, SendMessageSuccessResponse, - SendStreamingMessageRequest, - SendStreamingMessageResponse, - SetTaskPushNotificationConfigRequest, - SetTaskPushNotificationConfigResponse, - SetTaskPushNotificationConfigSuccessResponse, + Task, TaskIdParams, - TaskNotCancelableError, TaskPushNotificationConfig, TaskQueryParams, ) @@ -117,15 +101,13 @@ def mock_httpx_client() -> AsyncMock: @pytest.fixture def mock_agent_card() -> MagicMock: mock = MagicMock(spec=AgentCard, url='http://agent.example.com/api') - # The attribute is accessed in the client's __init__ to determine if an - # extended card needs to be fetched. mock.supports_authenticated_extended_card = False return mock async def async_iterable_from_list( items: list[ServerSentEvent], -) -> AsyncGenerator[ServerSentEvent]: +) -> AsyncGenerator[ServerSentEvent, None]: """Helper to create an async iterable from a list.""" for item in items: yield item @@ -135,9 +117,7 @@ class TestA2ACardResolver: BASE_URL = 'http://example.com' AGENT_CARD_PATH = AGENT_CARD_WELL_KNOWN_PATH FULL_AGENT_CARD_URL = f'{BASE_URL}{AGENT_CARD_PATH}' - EXTENDED_AGENT_CARD_PATH = ( - '/agent/authenticatedExtendedCard' # Default path - ) + EXTENDED_AGENT_CARD_PATH = '/agent/authenticatedExtendedCard' @pytest.mark.asyncio async def test_init_parameters_stored_correctly( @@ -154,7 +134,6 @@ async def test_init_parameters_stored_correctly( assert resolver.agent_card_path == custom_path.lstrip('/') assert resolver.httpx_client == mock_httpx_client - # Test default agent_card_path resolver_default_path = A2ACardResolver( httpx_client=mock_httpx_client, base_url=base_url, @@ -168,13 +147,10 @@ async def test_init_parameters_stored_correctly( async def test_init_strips_slashes(self, mock_httpx_client: AsyncMock): resolver = A2ACardResolver( httpx_client=mock_httpx_client, - base_url='http://example.com/', # With trailing slash - agent_card_path='/.well-known/agent.json/', # With leading/trailing slash + base_url='http://example.com/', + agent_card_path='/.well-known/agent.json/', ) - assert ( - resolver.base_url == 'http://example.com' - ) # Trailing slash stripped - # constructor lstrips agent_card_path, but keeps trailing if provided + assert resolver.base_url == 'http://example.com' assert resolver.agent_card_path == '.well-known/agent.json/' @pytest.mark.asyncio @@ -199,7 +175,6 @@ async def test_get_agent_card_success_public_only( mock_response.raise_for_status.assert_called_once() assert isinstance(agent_card, AgentCard) assert agent_card == AGENT_CARD - # Ensure only one call was made (for the public card) assert mock_httpx_client.get.call_count == 1 @pytest.mark.asyncio @@ -211,8 +186,6 @@ async def test_get_agent_card_success_with_specified_path_for_extended_card( extended_card_response.json.return_value = ( AGENT_CARD_EXTENDED.model_dump(mode='json') ) - - # Mock the single call for the extended card mock_httpx_client.get.return_value = extended_card_response resolver = A2ACardResolver( @@ -221,7 +194,6 @@ async def test_get_agent_card_success_with_specified_path_for_extended_card( agent_card_path=self.AGENT_CARD_PATH, ) - # Fetch the extended card by providing its relative path and example auth auth_kwargs = {'headers': {'Authorization': 'Bearer test token'}} agent_card_result = await resolver.get_agent_card( relative_card_path=self.EXTENDED_AGENT_CARD_PATH, @@ -235,11 +207,8 @@ async def test_get_agent_card_success_with_specified_path_for_extended_card( expected_extended_url, **auth_kwargs ) extended_card_response.raise_for_status.assert_called_once() - assert isinstance(agent_card_result, AgentCard) - assert ( - agent_card_result == AGENT_CARD_EXTENDED - ) # Should return the extended card + assert agent_card_result == AGENT_CARD_EXTENDED @pytest.mark.asyncio async def test_get_agent_card_validation_error( @@ -247,7 +216,6 @@ async def test_get_agent_card_validation_error( ): mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 - # Data that will cause a Pydantic ValidationError mock_response.json.return_value = { 'invalid_field': 'value', 'name': 'Test Agent', @@ -257,31 +225,23 @@ async def test_get_agent_card_validation_error( resolver = A2ACardResolver( httpx_client=mock_httpx_client, base_url=self.BASE_URL ) - # The call that is expected to raise an error should be within pytest.raises with pytest.raises(A2AClientJSONError) as exc_info: - await resolver.get_agent_card() # Fetches from default path + await resolver.get_agent_card() assert ( f'Failed to validate agent card structure from {self.FULL_AGENT_CARD_URL}' in str(exc_info.value) ) - assert 'invalid_field' in str( - exc_info.value - ) # Check if Pydantic error details are present - assert ( - mock_httpx_client.get.call_count == 1 - ) # Should only be called once + assert 'invalid_field' in str(exc_info.value) + assert mock_httpx_client.get.call_count == 1 @pytest.mark.asyncio async def test_get_agent_card_http_status_error( self, mock_httpx_client: AsyncMock ): - mock_response = MagicMock( - spec=httpx.Response - ) # Use MagicMock for response attribute + mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 404 mock_response.text = 'Not Found' - http_status_error = httpx.HTTPStatusError( 'Not Found', request=MagicMock(), response=mock_response ) @@ -310,7 +270,6 @@ async def test_get_agent_card_json_decode_error( ): mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 - # Define json_error before using it json_error = json.JSONDecodeError('Expecting value', 'doc', 0) mock_response.json.side_effect = json_error mock_httpx_client.get.return_value = mock_response @@ -324,7 +283,6 @@ async def test_get_agent_card_json_decode_error( with pytest.raises(A2AClientJSONError) as exc_info: await resolver.get_agent_card() - # Assertions using exc_info must be after the with block assert ( f'Failed to parse JSON for agent card from {self.FULL_AGENT_CARD_URL}' in str(exc_info.value) @@ -357,430 +315,153 @@ async def test_get_agent_card_request_error( mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL) -class TestA2AClient: +class TestJsonRpcTransport: AGENT_URL = 'http://agent.example.com/api' def test_init_with_agent_card( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) assert client.url == mock_agent_card.url assert client.httpx_client == mock_httpx_client def test_init_with_url(self, mock_httpx_client: AsyncMock): - client = A2AClient(httpx_client=mock_httpx_client, url=self.AGENT_URL) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, url=self.AGENT_URL + ) assert client.url == self.AGENT_URL assert client.httpx_client == mock_httpx_client - def test_init_with_agent_card_and_url_prioritizes_agent_card( + def test_init_with_agent_card_and_url_prioritizes_url( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, url='http://otherurl.com', ) - assert ( - client.url == mock_agent_card.url - ) # Agent card URL should be used + assert client.url == 'http://otherurl.com' def test_init_raises_value_error_if_no_card_or_url( self, mock_httpx_client: AsyncMock ): with pytest.raises(ValueError) as exc_info: - A2AClient(httpx_client=mock_httpx_client) + JsonRpcTransport(httpx_client=mock_httpx_client) assert 'Must provide either agent_card or url' in str(exc_info.value) @pytest.mark.asyncio - async def test_get_client_from_agent_card_url_success( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - base_url = 'http://example.com' - agent_card_path = '/.well-known/custom-agent.json' - resolver_kwargs = {'timeout': 30} - - mock_resolver_instance = AsyncMock(spec=A2ACardResolver) - mock_resolver_instance.get_agent_card.return_value = mock_agent_card - - with patch( - 'a2a.client.jsonrpc_client.A2ACardResolver', - return_value=mock_resolver_instance, - ) as mock_resolver_class: - client = await A2AClient.get_client_from_agent_card_url( - httpx_client=mock_httpx_client, - base_url=base_url, - agent_card_path=agent_card_path, - http_kwargs=resolver_kwargs, - ) - - mock_resolver_class.assert_called_once_with( - mock_httpx_client, - base_url=base_url, - agent_card_path=agent_card_path, - ) - mock_resolver_instance.get_agent_card.assert_called_once_with( - http_kwargs=resolver_kwargs, - # relative_card_path=None is implied by not passing it - ) - assert isinstance(client, A2AClient) - assert client.url == mock_agent_card.url - assert client.httpx_client == mock_httpx_client - - @pytest.mark.asyncio - async def test_get_client_from_agent_card_url_resolver_error( - self, mock_httpx_client: AsyncMock - ): - error_to_raise = A2AClientHTTPError(404, 'Agent card not found') - with patch( - 'a2a.client.jsonrpc_client.A2ACardResolver.get_agent_card', - new_callable=AsyncMock, - side_effect=error_to_raise, - ): - with pytest.raises(A2AClientHTTPError) as exc_info: - await A2AClient.get_client_from_agent_card_url( - httpx_client=mock_httpx_client, - base_url='http://example.com', - ) - assert exc_info.value == error_to_raise - - @pytest.mark.asyncio - async def test_send_message_success_use_request( + async def test_send_message_success( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - params = MessageSendParams( message=create_text_message_object(content='Hello') ) - - request = SendMessageRequest(id=123, params=params) - success_response = create_text_message_object( role=Role.agent, content='Hi there!' - ).model_dump(exclude_none=True) + ) + rpc_response = SendMessageSuccessResponse( + id='123', jsonrpc='2.0', result=success_response + ) + response = httpx.Response( + 200, json=rpc_response.model_dump(mode='json') + ) + response.request = httpx.Request('POST', 'http://agent.example.com/api') + mock_httpx_client.post.return_value = response - rpc_response: dict[str, Any] = { - 'id': 123, - 'jsonrpc': '2.0', - 'result': success_response, - } + response = await client.send_message(request=params) - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response - response = await client.send_message( - request=request, http_kwargs={'timeout': 10} - ) - - assert mock_send_req.call_count == 1 - called_args, called_kwargs = mock_send_req.call_args - assert not called_kwargs # no kwargs to _send_request - assert len(called_args) == 2 - json_rpc_request: dict[str, Any] = called_args[0] - assert isinstance(json_rpc_request['id'], int) - http_kwargs: dict[str, Any] = called_args[1] - assert http_kwargs['timeout'] == 10 - - a2a_request_arg = A2ARequest.model_validate(json_rpc_request) - assert isinstance(a2a_request_arg.root, SendMessageRequest) - assert isinstance(a2a_request_arg.root.params, MessageSendParams) - - assert a2a_request_arg.root.params.model_dump( - exclude_none=True - ) == params.model_dump(exclude_none=True) - - assert isinstance(response, SendMessageResponse) - assert isinstance(response.root, SendMessageSuccessResponse) - assert ( - response.root.result.model_dump(exclude_none=True) - == success_response - ) + assert isinstance(response, Message) + assert response.model_dump() == success_response.model_dump() @pytest.mark.asyncio async def test_send_message_error_response( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - params = MessageSendParams( message=create_text_message_object(content='Hello') ) - - request = SendMessageRequest(id=123, params=params) - error_response = InvalidParamsError() - - rpc_response: dict[str, Any] = { - 'id': 123, + rpc_response = { + 'id': '123', 'jsonrpc': '2.0', 'error': error_response.model_dump(exclude_none=True), } + mock_httpx_client.post.return_value.json.return_value = rpc_response - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response - response = await client.send_message(request=request) - - assert isinstance(response, SendMessageResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - exclude_none=True - ) == InvalidParamsError().model_dump(exclude_none=True) + with pytest.raises(Exception): + await client.send_message(request=params) @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_success_request( + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_success( self, mock_aconnect_sse: AsyncMock, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) params = MessageSendParams( message=create_text_message_object(content='Hello stream') ) - - request = SendStreamingMessageRequest(id=123, params=params) - - mock_stream_response_1_dict: dict[str, Any] = { - 'id': 'stream_id_123', - 'jsonrpc': '2.0', - 'result': create_text_message_object( + mock_stream_response_1 = SendMessageSuccessResponse( + id='stream_id_123', + jsonrpc='2.0', + result=create_text_message_object( content='First part ', role=Role.agent - ).model_dump(mode='json', exclude_none=True), - } - mock_stream_response_2_dict: dict[str, Any] = { - 'id': 'stream_id_123', - 'jsonrpc': '2.0', - 'result': create_text_message_object( - content='second part ', role=Role.agent - ).model_dump(mode='json', exclude_none=True), - } - - sse_event_1 = ServerSentEvent( - data=json.dumps(mock_stream_response_1_dict) - ) - sse_event_2 = ServerSentEvent( - data=json.dumps(mock_stream_response_2_dict) - ) - - mock_event_source = AsyncMock(spec=EventSource) - with patch.object(mock_event_source, 'aiter_sse') as mock_aiter_sse: - mock_aiter_sse.return_value = async_iterable_from_list( - [sse_event_1, sse_event_2] - ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - results: list[Any] = [] - async for response in client.send_message_streaming( - request=request - ): - results.append(response) - - assert len(results) == 2 - assert isinstance(results[0], SendStreamingMessageResponse) - # Assuming SendStreamingMessageResponse is a RootModel like SendMessageResponse - assert results[0].root.id == 'stream_id_123' - assert ( - results[0].root.result.model_dump( # type: ignore - mode='json', exclude_none=True - ) - == mock_stream_response_1_dict['result'] - ) - - assert isinstance(results[1], SendStreamingMessageResponse) - assert results[1].root.id == 'stream_id_123' - assert ( - results[1].root.result.model_dump( # type: ignore - mode='json', exclude_none=True - ) - == mock_stream_response_2_dict['result'] - ) - - mock_aconnect_sse.assert_called_once() - call_args, call_kwargs = mock_aconnect_sse.call_args - assert call_args[0] == mock_httpx_client - assert call_args[1] == 'POST' - assert call_args[2] == mock_agent_card.url - - sent_json_payload = call_kwargs['json'] - assert sent_json_payload['method'] == 'message/stream' - assert sent_json_payload['params'] == params.model_dump( - mode='json', exclude_none=True - ) - assert ( - call_kwargs['timeout'] is None - ) # Default timeout for streaming - - @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_http_kwargs_passed( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - params = MessageSendParams( - message=create_text_message_object(content='Stream with kwargs') - ) - request = SendStreamingMessageRequest(id='kwarg_req', params=params) - custom_kwargs = { - 'headers': {'X-Custom-Header': 'TestValue'}, - 'timeout': 60, - } - - # Setup mock_aconnect_sse to behave minimally - mock_event_source = AsyncMock(spec=EventSource) - mock_event_source.aiter_sse.return_value = async_iterable_from_list( - [] - ) # No events needed for this test - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source - ) - - async for _ in client.send_message_streaming( - request=request, http_kwargs=custom_kwargs - ): - pass # We just want to check the call to aconnect_sse - - mock_aconnect_sse.assert_called_once() - _, called_kwargs = mock_aconnect_sse.call_args - assert called_kwargs['headers'] == custom_kwargs['headers'] - assert ( - called_kwargs['timeout'] == custom_kwargs['timeout'] - ) # Ensure custom timeout is used - - @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_sse_error_handling( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - request = SendStreamingMessageRequest( - id='sse_err_req', - params=MessageSendParams( - message=create_text_message_object(content='SSE error test') ), ) - - # Configure the mock aconnect_sse to raise SSEError when aiter_sse is called - mock_event_source = AsyncMock(spec=EventSource) - mock_event_source.aiter_sse.side_effect = SSEError( - 'Simulated SSE protocol error' - ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source + mock_stream_response_2 = SendMessageSuccessResponse( + id='stream_id_123', + jsonrpc='2.0', + result=create_text_message_object( + content='second part ', role=Role.agent + ), ) - - with pytest.raises(A2AClientHTTPError) as exc_info: - async for _ in client.send_message_streaming(request=request): - pass - - assert exc_info.value.status_code == 400 # As per client implementation - assert 'Invalid SSE response or protocol error' in str(exc_info.value) - assert 'Simulated SSE protocol error' in str(exc_info.value) - - @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_json_decode_error_handling( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card + sse_event_1 = ServerSentEvent( + data=mock_stream_response_1.model_dump_json() ) - request = SendStreamingMessageRequest( - id='json_err_req', - params=MessageSendParams( - message=create_text_message_object(content='JSON error test') - ), + sse_event_2 = ServerSentEvent( + data=mock_stream_response_2.model_dump_json() ) - - # Malformed JSON event - malformed_sse_event = ServerSentEvent(data='not valid json') - mock_event_source = AsyncMock(spec=EventSource) - # json.loads will be called on "not valid json" and raise JSONDecodeError mock_event_source.aiter_sse.return_value = async_iterable_from_list( - [malformed_sse_event] + [sse_event_1, sse_event_2] ) mock_aconnect_sse.return_value.__aenter__.return_value = ( mock_event_source ) - with pytest.raises(A2AClientJSONError) as exc_info: - async for _ in client.send_message_streaming(request=request): - pass - - assert 'Expecting value: line 1 column 1 (char 0)' in str( - exc_info.value - ) # Example of JSONDecodeError message - - @pytest.mark.asyncio - @patch('a2a.client.jsonrpc_client.aconnect_sse') - async def test_send_message_streaming_httpx_request_error_handling( - self, - mock_aconnect_sse: AsyncMock, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card - ) - request = SendStreamingMessageRequest( - id='httpx_err_req', - params=MessageSendParams( - message=create_text_message_object(content='httpx error test') - ), - ) + results = [ + item async for item in client.send_message_streaming(request=params) + ] - # Configure aconnect_sse itself to raise httpx.RequestError (e.g., during connection) - # This needs to be raised when aconnect_sse is entered or iterated. - # One way is to make the context manager's __aenter__ raise it, or aiter_sse. - # For simplicity, let's make aiter_sse raise it, as if the error occurs after connection. - mock_event_source = AsyncMock(spec=EventSource) - mock_event_source.aiter_sse.side_effect = httpx.RequestError( - 'Simulated network error', request=MagicMock() + assert len(results) == 2 + assert isinstance(results[0], Message) + assert ( + results[0].model_dump() + == mock_stream_response_1.result.model_dump() ) - mock_aconnect_sse.return_value.__aenter__.return_value = ( - mock_event_source + assert isinstance(results[1], Message) + assert ( + results[1].model_dump() + == mock_stream_response_2.result.model_dump() ) - with pytest.raises(A2AClientHTTPError) as exc_info: - async for _ in client.send_message_streaming(request=request): - pass - - assert exc_info.value.status_code == 503 # As per client implementation - assert 'Network communication error' in str(exc_info.value) - assert 'Simulated network error' in str(exc_info.value) - @pytest.mark.asyncio async def test_send_request_http_status_error( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) mock_response = MagicMock(spec=httpx.Response) @@ -801,7 +482,7 @@ async def test_send_request_http_status_error( async def test_send_request_json_decode_error( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) mock_response = AsyncMock(spec=httpx.Response) @@ -819,7 +500,7 @@ async def test_send_request_json_decode_error( async def test_send_request_httpx_request_error( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) request_error = httpx.RequestError('Network issue', request=MagicMock()) @@ -833,468 +514,274 @@ async def test_send_request_httpx_request_error( assert 'Network issue' in str(exc_info.value) @pytest.mark.asyncio - async def test_set_task_callback_success( + async def test_send_message_client_timeout( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card + mock_httpx_client.post.side_effect = httpx.ReadTimeout( + 'Request timed out' ) - task_id_val = 'task_set_cb_001' - # Correctly create the PushNotificationConfig (inner model) - push_config_payload = PushNotificationConfig( - url='https://callback.example.com/taskupdate' + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - # Correctly create the TaskPushNotificationConfig (outer model) - params_model = TaskPushNotificationConfig( - task_id=task_id_val, push_notification_config=push_config_payload + params = MessageSendParams( + message=create_text_message_object(content='Hello') ) - # request.id will be generated by the client method if not provided - request = SetTaskPushNotificationConfigRequest( - id='', params=params_model - ) # Test ID auto-generation - - # The result for a successful set operation is the same config - rpc_response_payload: dict[str, Any] = { - 'id': ANY, # Will be checked against generated ID - 'jsonrpc': '2.0', - 'result': params_model.model_dump(mode='json', exclude_none=True), - } + with pytest.raises(A2AClientTimeoutError) as exc_info: + await client.send_message(request=params) - with ( - patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req, - patch( - 'a2a.client.jsonrpc_client.uuid4', - return_value=MagicMock(hex='testuuid'), - ) as mock_uuid, - ): - # Capture the generated ID for assertion - generated_id = str(mock_uuid.return_value) - rpc_response_payload['id'] = ( - generated_id # Ensure mock response uses the generated ID - ) - mock_send_req.return_value = rpc_response_payload - - response = await client.set_task_callback(request=request) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args - sent_json_payload = called_args[0] - - assert sent_json_payload['id'] == generated_id - assert ( - sent_json_payload['method'] - == 'tasks/pushNotificationConfig/set' - ) - assert sent_json_payload['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, SetTaskPushNotificationConfigResponse) - assert isinstance( - response.root, SetTaskPushNotificationConfigSuccessResponse - ) - assert response.root.id == generated_id - assert response.root.result.model_dump( - mode='json', exclude_none=True - ) == params_model.model_dump(mode='json', exclude_none=True) + assert 'Client Request timed out' in str(exc_info.value) @pytest.mark.asyncio - async def test_set_task_callback_error_response( + async def test_get_task_success( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - req_id = 'set_cb_err_req' - push_config_payload = PushNotificationConfig(url='https://errors.com') - params_model = TaskPushNotificationConfig( - task_id='task_err_cb', push_notification_config=push_config_payload - ) - request = SetTaskPushNotificationConfigRequest( - id=req_id, params=params_model - ) - error_details = InvalidParamsError(message='Invalid callback URL') - - rpc_response_payload: dict[str, Any] = { - 'id': req_id, + params = TaskQueryParams(id='task-abc') + rpc_response = { + 'id': '123', 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), + 'result': MINIMAL_TASK, } - with patch.object( client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.set_task_callback(request=request) + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.get_task(request=params) - assert isinstance(response, SetTaskPushNotificationConfigResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(mode='json', exclude_none=True) - assert response.root.id == req_id + assert isinstance(response, Task) + assert ( + response.model_dump() + == Task.model_validate(MINIMAL_TASK).model_dump() + ) + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/get' @pytest.mark.asyncio - async def test_set_task_callback_http_kwargs_passed( + async def test_cancel_task_success( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - push_config_payload = PushNotificationConfig(url='https://kwargs.com') - params_model = TaskPushNotificationConfig( - task_id='task_cb_kwargs', - push_notification_config=push_config_payload, - ) - request = SetTaskPushNotificationConfigRequest( - id='cb_kwargs_req', params=params_model - ) - custom_kwargs = {'headers': {'X-Callback-Token': 'secret'}} - - # Minimal successful response - rpc_response_payload: dict[str, Any] = { - 'id': 'cb_kwargs_req', + params = TaskIdParams(id='task-abc') + rpc_response = { + 'id': '123', 'jsonrpc': '2.0', - 'result': params_model.model_dump(mode='json'), + 'result': MINIMAL_CANCELLED_TASK, } - with patch.object( client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - await client.set_task_callback( - request=request, http_kwargs=custom_kwargs - ) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args # Correctly unpack args - assert ( - called_args[1] == custom_kwargs - ) # http_kwargs is the second positional arg + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.cancel_task(request=params) + + assert isinstance(response, Task) + assert ( + response.model_dump() + == Task.model_validate(MINIMAL_CANCELLED_TASK).model_dump() + ) + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/cancel' @pytest.mark.asyncio - async def test_get_task_callback_success( + async def test_set_task_callback_success( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - task_id_val = 'task_get_cb_001' - params_model = TaskIdParams( - id=task_id_val - ) # Params for get is just TaskIdParams - - request = GetTaskPushNotificationConfigRequest( - id='', params=params_model - ) # ID is empty string for auto-generation test - - # Expected result for a successful get operation - push_config_payload = PushNotificationConfig( - url='https://callback.example.com/taskupdate' - ) - expected_callback_config = TaskPushNotificationConfig( - task_id=task_id_val, push_notification_config=push_config_payload + params = TaskPushNotificationConfig( + task_id='task-abc', + push_notification_config=PushNotificationConfig( + url='http://callback.com' + ), ) - rpc_response_payload: dict[str, Any] = { - 'id': ANY, + rpc_response = { + 'id': '123', 'jsonrpc': '2.0', - 'result': expected_callback_config.model_dump( - mode='json', exclude_none=True - ), + 'result': params.model_dump(mode='json'), } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.set_task_callback(request=params) - with ( - patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req, - patch( - 'a2a.client.jsonrpc_client.uuid4', - return_value=MagicMock(hex='testgetuuid'), - ) as mock_uuid, - ): - generated_id = str(mock_uuid.return_value) - rpc_response_payload['id'] = generated_id - mock_send_req.return_value = rpc_response_payload - - response = await client.get_task_callback(request=request) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args - sent_json_payload = called_args[0] - - assert sent_json_payload['id'] == generated_id - assert ( - sent_json_payload['method'] - == 'tasks/pushNotificationConfig/get' - ) - assert sent_json_payload['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, GetTaskPushNotificationConfigResponse) - assert isinstance( - response.root, GetTaskPushNotificationConfigSuccessResponse - ) - assert response.root.id == generated_id - assert response.root.result.model_dump( - mode='json', exclude_none=True - ) == expected_callback_config.model_dump( - mode='json', exclude_none=True - ) + assert isinstance(response, TaskPushNotificationConfig) + assert response.model_dump() == params.model_dump() + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/pushNotificationConfig/set' @pytest.mark.asyncio - async def test_get_task_callback_error_response( + async def test_get_task_callback_success( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - req_id = 'get_cb_err_req' - params_model = TaskIdParams(id='task_get_err_cb') - request = GetTaskPushNotificationConfigRequest( - id=req_id, params=params_model + params = TaskIdParams(id='task-abc') + expected_response = TaskPushNotificationConfig( + task_id='task-abc', + push_notification_config=PushNotificationConfig( + url='http://callback.com' + ), ) - error_details = TaskNotCancelableError( - message='Cannot get callback for uncancelable task' - ) # Example error - - rpc_response_payload: dict[str, Any] = { - 'id': req_id, + rpc_response = { + 'id': '123', 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), + 'result': expected_response.model_dump(mode='json'), } - with patch.object( client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.get_task_callback(request=request) + ) as mock_send_request: + mock_send_request.return_value = rpc_response + response = await client.get_task_callback(request=params) - assert isinstance(response, GetTaskPushNotificationConfigResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(mode='json', exclude_none=True) - assert response.root.id == req_id + assert isinstance(response, TaskPushNotificationConfig) + assert response.model_dump() == expected_response.model_dump() + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'tasks/pushNotificationConfig/get' @pytest.mark.asyncio - async def test_get_task_callback_http_kwargs_passed( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_sse_error( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - params_model = TaskIdParams(id='task_get_cb_kwargs') - request = GetTaskPushNotificationConfigRequest( - id='get_cb_kwargs_req', params=params_model + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') ) - custom_kwargs = {'headers': {'X-Tenant-ID': 'tenant-x'}} - - # Correctly create the nested PushNotificationConfig - push_config_payload_for_expected = PushNotificationConfig( - url='https://getkwargs.com' + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.side_effect = SSEError( + 'Simulated SSE error' ) - expected_callback_config = TaskPushNotificationConfig( - task_id='task_get_cb_kwargs', - push_notification_config=push_config_payload_for_expected, + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source ) - rpc_response_payload: dict[str, Any] = { - 'id': 'get_cb_kwargs_req', - 'jsonrpc': '2.0', - 'result': expected_callback_config.model_dump(mode='json'), - } - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - await client.get_task_callback( - request=request, http_kwargs=custom_kwargs - ) - - mock_send_req.assert_called_once() - called_args, _ = mock_send_req.call_args # Correctly unpack args - assert ( - called_args[1] == custom_kwargs - ) # http_kwargs is the second positional arg + with pytest.raises(A2AClientHTTPError): + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] @pytest.mark.asyncio - async def test_get_task_success_use_request( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_json_error( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - task_id_val = 'task_for_req_obj' - params_model = TaskQueryParams(id=task_id_val) - request_obj_id = 789 - request = GetTaskRequest(id=request_obj_id, params=params_model) - - rpc_response_payload: dict[str, Any] = { - 'id': request_obj_id, - 'jsonrpc': '2.0', - 'result': MINIMAL_TASK, - } + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + sse_event = ServerSentEvent(data='{invalid json') + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list( + [sse_event] + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.get_task( - request=request, http_kwargs={'timeout': 20} - ) - - assert mock_send_req.call_count == 1 - called_args, called_kwargs = mock_send_req.call_args - assert len(called_args) == 2 - json_rpc_request_sent: dict[str, Any] = called_args[0] - assert not called_kwargs # no extra kwargs to _send_request - http_kwargs: dict[str, Any] = called_args[1] - assert http_kwargs['timeout'] == 20 - - assert json_rpc_request_sent['method'] == 'tasks/get' - assert json_rpc_request_sent['id'] == request_obj_id - assert json_rpc_request_sent['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, GetTaskResponse) - assert hasattr(response.root, 'result') - assert ( - response.root.result.model_dump(mode='json', exclude_none=True) # type: ignore - == MINIMAL_TASK - ) - assert response.root.id == request_obj_id + with pytest.raises(A2AClientJSONError): + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] @pytest.mark.asyncio - async def test_get_task_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_request_error( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, ): - client = A2AClient( + client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card ) - params_model = TaskQueryParams(id='task_error_case') - request = GetTaskRequest(id='err_req_id', params=params_model) - error_details = InvalidParamsError() - - rpc_response_payload: dict[str, Any] = { - 'id': 'err_req_id', - 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), - } - - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.get_task(request=request) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.side_effect = httpx.RequestError( + 'Simulated request error', request=MagicMock() + ) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) - assert isinstance(response, GetTaskResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(exclude_none=True) - assert response.root.id == 'err_req_id' + with pytest.raises(A2AClientHTTPError): + _ = [ + item + async for item in client.send_message_streaming(request=params) + ] @pytest.mark.asyncio - async def test_cancel_task_success_use_request( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + async def test_get_card_no_card_provided( + self, mock_httpx_client: AsyncMock ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card + client = JsonRpcTransport( + httpx_client=mock_httpx_client, url=self.AGENT_URL ) - task_id_val = MINIMAL_CANCELLED_TASK['id'] - params_model = TaskIdParams(id=task_id_val) - request_obj_id = 'cancel_req_obj_id_001' - request = CancelTaskRequest(id=request_obj_id, params=params_model) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = AGENT_CARD.model_dump(mode='json') + mock_httpx_client.get.return_value = mock_response - rpc_response_payload: dict[str, Any] = { - 'id': request_obj_id, - 'jsonrpc': '2.0', - 'result': MINIMAL_CANCELLED_TASK, - } + card = await client.get_card() - with patch.object( - client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.cancel_task( - request=request, http_kwargs={'timeout': 15} - ) - - assert mock_send_req.call_count == 1 - called_args, called_kwargs = mock_send_req.call_args - assert not called_kwargs # no extra kwargs to _send_request - assert len(called_args) == 2 - json_rpc_request_sent: dict[str, Any] = called_args[0] - http_kwargs: dict[str, Any] = called_args[1] - assert http_kwargs['timeout'] == 15 - - assert json_rpc_request_sent['method'] == 'tasks/cancel' - assert json_rpc_request_sent['id'] == request_obj_id - assert json_rpc_request_sent['params'] == params_model.model_dump( - mode='json', exclude_none=True - ) - - assert isinstance(response, CancelTaskResponse) - assert isinstance(response.root, CancelTaskSuccessResponse) - assert ( - response.root.result.model_dump(mode='json', exclude_none=True) # type: ignore - == MINIMAL_CANCELLED_TASK - ) - assert response.root.id == request_obj_id + assert card == AGENT_CARD + mock_httpx_client.get.assert_called_once() @pytest.mark.asyncio - async def test_cancel_task_error_response( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + async def test_get_card_with_extended_card_support( + self, mock_httpx_client: AsyncMock ): - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card + agent_card = AGENT_CARD.model_copy( + update={'supports_authenticated_extended_card': True} + ) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, agent_card=agent_card ) - params_model = TaskIdParams(id='task_cancel_error_case') - request = CancelTaskRequest(id='err_cancel_req', params=params_model) - error_details = TaskNotCancelableError() - rpc_response_payload: dict[str, Any] = { - 'id': 'err_cancel_req', + rpc_response = { + 'id': '123', 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), + 'result': AGENT_CARD_EXTENDED.model_dump(mode='json'), } - with patch.object( client, '_send_request', new_callable=AsyncMock - ) as mock_send_req: - mock_send_req.return_value = rpc_response_payload - response = await client.cancel_task(request=request) + ) as mock_send_request: + mock_send_request.return_value = rpc_response + card = await client.get_card() - assert isinstance(response, CancelTaskResponse) - assert isinstance(response.root, JSONRPCErrorResponse) - assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(exclude_none=True) - assert response.root.id == 'err_cancel_req' + assert card == agent_card + mock_send_request.assert_called_once() + sent_payload = mock_send_request.call_args.args[0] + assert sent_payload['method'] == 'agent/getAuthenticatedExtendedCard' @pytest.mark.asyncio - async def test_send_message_client_timeout( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - mock_httpx_client.post.side_effect = httpx.ReadTimeout( - 'Request timed out' - ) - client = A2AClient( - httpx_client=mock_httpx_client, agent_card=mock_agent_card + async def test_close(self, mock_httpx_client: AsyncMock): + client = JsonRpcTransport( + httpx_client=mock_httpx_client, url=self.AGENT_URL ) - - params = MessageSendParams( - message=create_text_message_object(content='Hello') - ) - - request = SendMessageRequest(id=123, params=params) - - with pytest.raises(A2AClientTimeoutError) as exc_info: - await client.send_message(request=request) - - assert 'Request timed out' in str(exc_info.value) + await client.close() + mock_httpx_client.aclose.assert_called_once() diff --git a/tests/client/test_legacy_client.py b/tests/client/test_legacy_client.py new file mode 100644 index 00000000..247f0b18 --- /dev/null +++ b/tests/client/test_legacy_client.py @@ -0,0 +1,116 @@ +"""Tests for the legacy client compatibility layer.""" + +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest + +from a2a.client import A2AClient, A2AGrpcClient +from a2a.types import ( + AgentCard, + AgentCapabilities, + Message, + Role, + TextPart, + Part, + Task, + TaskStatus, + TaskState, + TaskQueryParams, + SendMessageRequest, + MessageSendParams, + GetTaskRequest, +) + + +@pytest.fixture +def mock_httpx_client() -> AsyncMock: + return AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def mock_grpc_stub() -> AsyncMock: + stub = AsyncMock() + stub._channel = MagicMock() + return stub + + +@pytest.fixture +def jsonrpc_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='A test agent', + url='http://test.agent.com/rpc', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport='jsonrpc', + ) + + +@pytest.fixture +def grpc_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='A test agent', + url='http://test.agent.com/rpc', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True), + skills=[], + default_input_modes=[], + default_output_modes=[], + preferred_transport='grpc', + ) + + +@pytest.mark.asyncio +async def test_a2a_client_send_message( + mock_httpx_client: AsyncMock, jsonrpc_agent_card: AgentCard +): + client = A2AClient( + httpx_client=mock_httpx_client, agent_card=jsonrpc_agent_card + ) + + # Mock the underlying transport's send_message method + mock_response_task = Task( + id='task-123', + context_id='ctx-456', + status=TaskStatus(state=TaskState.completed), + ) + + client._transport.send_message = AsyncMock(return_value=mock_response_task) + + message = Message( + message_id='msg-123', + role=Role.user, + parts=[Part(root=TextPart(text='Hello'))], + ) + request = SendMessageRequest( + id='req-123', params=MessageSendParams(message=message) + ) + response = await client.send_message(request) + + assert response.root.result.id == 'task-123' + + +@pytest.mark.asyncio +async def test_a2a_grpc_client_get_task( + mock_grpc_stub: AsyncMock, grpc_agent_card: AgentCard +): + client = A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=grpc_agent_card) + + mock_response_task = Task( + id='task-456', + context_id='ctx-789', + status=TaskStatus(state=TaskState.working), + ) + + client.get_task = AsyncMock(return_value=mock_response_task) + + params = TaskQueryParams(id='task-456') + response = await client.get_task(params) + + assert response.id == 'task-456' + client.get_task.assert_awaited_once_with(params) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py new file mode 100644 index 00000000..928ab2ea --- /dev/null +++ b/tests/integration/test_client_server_integration.py @@ -0,0 +1,747 @@ +import asyncio + +from collections.abc import AsyncGenerator +from typing import NamedTuple +from unittest.mock import ANY, AsyncMock + +import grpc +import httpx +import pytest +import pytest_asyncio + +from grpc.aio import Channel + +from a2a.client.transports import JsonRpcTransport, RestTransport +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.grpc import GrpcTransport +from a2a.grpc import a2a_pb2_grpc +from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from a2a.server.request_handlers import GrpcHandler, RequestHandler +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentInterface, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Part, + PushNotificationConfig, + Role, + Task, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, + TransportProtocol, +) + + +# --- Test Constants --- + +TASK_FROM_STREAM = Task( + id='task-123-stream', + context_id='ctx-456-stream', + status=TaskStatus(state=TaskState.completed), + kind='task', +) + +TASK_FROM_BLOCKING = Task( + id='task-789-blocking', + context_id='ctx-101-blocking', + status=TaskStatus(state=TaskState.completed), + kind='task', +) + +GET_TASK_RESPONSE = Task( + id='task-get-456', + context_id='ctx-get-789', + status=TaskStatus(state=TaskState.working), + kind='task', +) + +CANCEL_TASK_RESPONSE = Task( + id='task-cancel-789', + context_id='ctx-cancel-101', + status=TaskStatus(state=TaskState.canceled), + kind='task', +) + +CALLBACK_CONFIG = TaskPushNotificationConfig( + task_id='task-callback-123', + push_notification_config=PushNotificationConfig( + id='pnc-abc', url='http://callback.example.com', token='' + ), +) + +RESUBSCRIBE_EVENT = TaskStatusUpdateEvent( + task_id='task-resub-456', + context_id='ctx-resub-789', + status=TaskStatus(state=TaskState.working), + final=False, +) + + +# --- Test Fixtures --- + + +@pytest.fixture +def mock_request_handler() -> AsyncMock: + """Provides a mock RequestHandler for the server-side handlers.""" + handler = AsyncMock(spec=RequestHandler) + + # Configure on_message_send for non-streaming calls + handler.on_message_send.return_value = TASK_FROM_BLOCKING + + # Configure on_message_send_stream for streaming calls + async def stream_side_effect(*args, **kwargs): + yield TASK_FROM_STREAM + + handler.on_message_send_stream.side_effect = stream_side_effect + + # Configure other methods + handler.on_get_task.return_value = GET_TASK_RESPONSE + handler.on_cancel_task.return_value = CANCEL_TASK_RESPONSE + handler.on_set_task_push_notification_config.side_effect = ( + lambda params, context: params + ) + handler.on_get_task_push_notification_config.return_value = CALLBACK_CONFIG + + async def resubscribe_side_effect(*args, **kwargs): + yield RESUBSCRIBE_EVENT + + handler.on_resubscribe_to_task.side_effect = resubscribe_side_effect + + return handler + + +@pytest.fixture +def agent_card() -> AgentCard: + """Provides a sample AgentCard for tests.""" + return AgentCard( + name='Test Agent', + description='An agent for integration testing.', + url='http://testserver', + version='1.0.0', + capabilities=AgentCapabilities(streaming=True, push_notifications=True), + skills=[], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + preferred_transport=TransportProtocol.jsonrpc, + supports_authenticated_extended_card=True, + additional_interfaces=[ + AgentInterface( + transport=TransportProtocol.http_json, url='http://testserver' + ), + AgentInterface( + transport=TransportProtocol.grpc, url='localhost:50051' + ), + ], + ) + + +class TransportSetup(NamedTuple): + """Holds the transport and handler for a given test.""" + + transport: ClientTransport + handler: AsyncMock + + +# --- HTTP/JSON-RPC/REST Setup --- + + +@pytest.fixture +def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): + """A base fixture to patch the sse-starlette event loop issue.""" + from sse_starlette import sse + + sse.AppStatus.should_exit_event = asyncio.Event() + yield mock_request_handler, agent_card + + +@pytest.fixture +def jsonrpc_setup(http_base_setup) -> TransportSetup: + """Sets up the JsonRpcTransport and in-memory server.""" + mock_request_handler, agent_card = http_base_setup + app_builder = A2AFastAPIApplication( + agent_card, mock_request_handler, extended_agent_card=agent_card + ) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + transport = JsonRpcTransport( + httpx_client=httpx_client, agent_card=agent_card + ) + return TransportSetup(transport=transport, handler=mock_request_handler) + + +@pytest.fixture +def rest_setup(http_base_setup) -> TransportSetup: + """Sets up the RestTransport and in-memory server.""" + mock_request_handler, agent_card = http_base_setup + app_builder = A2ARESTFastAPIApplication(agent_card, mock_request_handler) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card) + return TransportSetup(transport=transport, handler=mock_request_handler) + + +# --- gRPC Setup --- + + +@pytest_asyncio.fixture +async def grpc_server_and_handler( + mock_request_handler: AsyncMock, agent_card: AgentCard +) -> AsyncGenerator[tuple[str, AsyncMock], None]: + """Creates and manages an in-process gRPC test server.""" + server = grpc.aio.server() + port = server.add_insecure_port('[::]:0') + server_address = f'localhost:{port}' + servicer = GrpcHandler(agent_card, mock_request_handler) + a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) + await server.start() + yield server_address, mock_request_handler + await server.stop(0) + + +# --- The Integration Tests --- + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_sends_message_streaming( + transport_setup_fixture: str, request +) -> None: + """ + Integration test for HTTP-based transports (JSON-RPC, REST) streaming. + """ + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + message_to_send = Message( + role=Role.user, + message_id='msg-integration-test', + parts=[Part(root=TextPart(text='Hello, integration test!'))], + ) + params = MessageSendParams(message=message_to_send) + + stream = transport.send_message_streaming(request=params) + first_event = await anext(stream) + + assert first_event.id == TASK_FROM_STREAM.id + assert first_event.context_id == TASK_FROM_STREAM.context_id + + handler.on_message_send_stream.assert_called_once() + call_args, _ = handler.on_message_send_stream.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_sends_message_streaming( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + """ + Integration test specifically for the gRPC transport streaming. + """ + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + message_to_send = Message( + role=Role.user, + message_id='msg-grpc-integration-test', + parts=[Part(root=TextPart(text='Hello, gRPC integration test!'))], + ) + params = MessageSendParams(message=message_to_send) + + stream = transport.send_message_streaming(request=params) + first_event = await anext(stream) + + assert first_event.id == TASK_FROM_STREAM.id + assert first_event.context_id == TASK_FROM_STREAM.context_id + + handler.on_message_send_stream.assert_called_once() + call_args, _ = handler.on_message_send_stream.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_sends_message_blocking( + transport_setup_fixture: str, request +) -> None: + """ + Integration test for HTTP-based transports (JSON-RPC, REST) blocking. + """ + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + message_to_send = Message( + role=Role.user, + message_id='msg-integration-test-blocking', + parts=[Part(root=TextPart(text='Hello, blocking test!'))], + ) + params = MessageSendParams(message=message_to_send) + + result = await transport.send_message(request=params) + + assert result.id == TASK_FROM_BLOCKING.id + assert result.context_id == TASK_FROM_BLOCKING.context_id + + handler.on_message_send.assert_awaited_once() + call_args, _ = handler.on_message_send.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_sends_message_blocking( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + """ + Integration test specifically for the gRPC transport blocking. + """ + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + message_to_send = Message( + role=Role.user, + message_id='msg-grpc-integration-test-blocking', + parts=[Part(root=TextPart(text='Hello, gRPC blocking test!'))], + ) + params = MessageSendParams(message=message_to_send) + + result = await transport.send_message(request=params) + + assert result.id == TASK_FROM_BLOCKING.id + assert result.context_id == TASK_FROM_BLOCKING.context_id + + handler.on_message_send.assert_awaited_once() + call_args, _ = handler.on_message_send.call_args + received_params: MessageSendParams = call_args[0] + + assert received_params.message.message_id == message_to_send.message_id + assert ( + received_params.message.parts[0].root.text + == message_to_send.parts[0].root.text + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_get_task( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = TaskQueryParams(id=GET_TASK_RESPONSE.id) + result = await transport.get_task(request=params) + + assert result.id == GET_TASK_RESPONSE.id + handler.on_get_task.assert_awaited_once_with(params, ANY) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_task( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = TaskQueryParams(id=GET_TASK_RESPONSE.id) + result = await transport.get_task(request=params) + + assert result.id == GET_TASK_RESPONSE.id + handler.on_get_task.assert_awaited_once() + assert handler.on_get_task.call_args[0][0].id == GET_TASK_RESPONSE.id + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_cancel_task( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = TaskIdParams(id=CANCEL_TASK_RESPONSE.id) + result = await transport.cancel_task(request=params) + + assert result.id == CANCEL_TASK_RESPONSE.id + handler.on_cancel_task.assert_awaited_once_with(params, ANY) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_cancel_task( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = TaskIdParams(id=CANCEL_TASK_RESPONSE.id) + result = await transport.cancel_task(request=params) + + assert result.id == CANCEL_TASK_RESPONSE.id + handler.on_cancel_task.assert_awaited_once() + assert handler.on_cancel_task.call_args[0][0].id == CANCEL_TASK_RESPONSE.id + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_set_task_callback( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = CALLBACK_CONFIG + result = await transport.set_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) + handler.on_set_task_push_notification_config.assert_awaited_once_with( + params, ANY + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_set_task_callback( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = CALLBACK_CONFIG + result = await transport.set_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) + handler.on_set_task_push_notification_config.assert_awaited_once() + assert ( + handler.on_set_task_push_notification_config.call_args[0][0].task_id + == CALLBACK_CONFIG.task_id + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_get_task_callback( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = GetTaskPushNotificationConfigParams( + id=CALLBACK_CONFIG.task_id, + push_notification_config_id=CALLBACK_CONFIG.push_notification_config.id, + ) + result = await transport.get_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) + handler.on_get_task_push_notification_config.assert_awaited_once_with( + params, ANY + ) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_task_callback( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = GetTaskPushNotificationConfigParams( + id=CALLBACK_CONFIG.task_id, + push_notification_config_id=CALLBACK_CONFIG.push_notification_config.id, + ) + result = await transport.get_task_callback(request=params) + + assert result.task_id == CALLBACK_CONFIG.task_id + assert ( + result.push_notification_config.id + == CALLBACK_CONFIG.push_notification_config.id + ) + assert ( + result.push_notification_config.url + == CALLBACK_CONFIG.push_notification_config.url + ) + handler.on_get_task_push_notification_config.assert_awaited_once() + assert ( + handler.on_get_task_push_notification_config.call_args[0][0].id + == CALLBACK_CONFIG.task_id + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_resubscribe( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = TaskIdParams(id=RESUBSCRIBE_EVENT.task_id) + stream = transport.resubscribe(request=params) + first_event = await anext(stream) + + assert first_event.task_id == RESUBSCRIBE_EVENT.task_id + handler.on_resubscribe_to_task.assert_called_once_with(params, ANY) + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_resubscribe( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + params = TaskIdParams(id=RESUBSCRIBE_EVENT.task_id) + stream = transport.resubscribe(request=params) + first_event = await anext(stream) + + assert first_event.task_id == RESUBSCRIBE_EVENT.task_id + handler.on_resubscribe_to_task.assert_called_once() + assert ( + handler.on_resubscribe_to_task.call_args[0][0].id + == RESUBSCRIBE_EVENT.task_id + ) + + await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_get_card( + transport_setup_fixture: str, request, agent_card: AgentCard +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + + # The transport starts with a minimal card, get_card() fetches the full one + transport.agent_card.supports_authenticated_extended_card = True + result = await transport.get_card() + + assert result.name == agent_card.name + assert transport.agent_card.name == agent_card.name + assert transport._needs_extended_card is False + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_card( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, _ = grpc_server_and_handler + agent_card.url = server_address + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address)) + transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card) + + # The transport starts with a minimal card, get_card() fetches the full one + transport.agent_card.supports_authenticated_extended_card = True + result = await transport.get_card() + + assert result.name == agent_card.name + assert transport.agent_card.name == agent_card.name + assert transport._needs_extended_card is False + + await transport.close() diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index 05c027fe..83848c24 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -237,14 +237,7 @@ def test_task_id_params_from_proto_invalid_name(self): assert isinstance(exc_info.value.error, types.InvalidParamsError) def test_task_push_config_from_proto_invalid_parent(self): - request = proto_utils.ToProto.task_push_notification_config( - types.TaskPushNotificationConfig( - task_id='test-task-id', - push_notification_config=types.PushNotificationConfig( - url='test_url' - ), - ) - ) + request = a2a_pb2.TaskPushNotificationConfig(name='invalid-name-format') with pytest.raises(ServerError) as exc_info: proto_utils.FromProto.task_push_notification_config(request) assert isinstance(exc_info.value.error, types.InvalidParamsError)