diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py new file mode 100644 index 00000000..2f752caa --- /dev/null +++ b/src/a2a/extensions/common.py @@ -0,0 +1,27 @@ +from a2a.types import AgentCard, AgentExtension + + +HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' + + +def get_requested_extensions(values: list[str]) -> set[str]: + """Get the set of requested extensions from an input list. + + This handles the list containing potentially comma-separated values, as + occurs when using a list in an HTTP header. + """ + return { + stripped + for v in values + for ext in v.split(',') + if (stripped := ext.strip()) + } + + +def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: + """Find an AgentExtension in an AgentCard given a uri.""" + for ext in card.capabilities.extensions or []: + if ext.uri == uri: + return ext + + return None diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index c992ba8e..8b1559aa 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -143,6 +143,24 @@ def metadata(self) -> dict[str, Any]: return {} return self._params.metadata or {} + def add_activated_extension(self, uri: str) -> None: + """Add an extension to the set of activated extensions for this request. + + This causes the extension to be indicated back to the client in the + response. + """ + if self._call_context: + self._call_context.activated_extensions.add(uri) + + @property + def requested_extensions(self) -> set[str]: + """Extensions that the client requested to activate.""" + return ( + self._call_context.requested_extensions + if self._call_context + else set() + ) + def _check_or_generate_task_id(self) -> None: """Ensures a task ID is present, generating one if necessary.""" if not self._params: diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 08e6b53f..e149b4d1 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -19,6 +19,10 @@ from a2a.auth.user import UnauthenticatedUser from a2a.auth.user import User as A2AUser +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, + get_requested_extensions, +) from a2a.server.context import ServerCallContext from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler @@ -99,7 +103,13 @@ def build(self, request: Request) -> ServerCallContext: user = StarletteUserProxy(request.user) state['auth'] = request.auth state['headers'] = dict(request.headers) - return ServerCallContext(user=user, state=state) + return ServerCallContext( + user=user, + state=state, + requested_extensions=get_requested_extensions( + request.headers.getlist(HTTP_EXTENSION_HEADER) + ), + ) class JSONRPCApplication(ABC): @@ -281,7 +291,7 @@ async def _process_streaming_request( request_obj, context ) - return self._create_response(handler_result) + return self._create_response(context, handler_result) async def _process_non_streaming_request( self, @@ -353,10 +363,11 @@ async def _process_non_streaming_request( id=request_id, error=error ) - return self._create_response(handler_result) + return self._create_response(context, handler_result) def _create_response( self, + context: ServerCallContext, handler_result: ( AsyncGenerator[SendStreamingMessageResponse] | JSONRPCErrorResponse @@ -372,12 +383,16 @@ def _create_response( payloads. Args: + context: The ServerCallContext provided to the request handler. handler_result: The result from a request handler method. Can be an async generator for streaming or a Pydantic model for non-streaming. Returns: A Starlette JSONResponse or EventSourceResponse. """ + headers = {} + if exts := context.activated_extensions: + headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts)) if isinstance(handler_result, AsyncGenerator): # Result is a stream of SendStreamingMessageResponse objects async def event_generator( @@ -386,17 +401,21 @@ async def event_generator( async for item in stream: yield {'data': item.root.model_dump_json(exclude_none=True)} - return EventSourceResponse(event_generator(handler_result)) + return EventSourceResponse( + event_generator(handler_result), headers=headers + ) if isinstance(handler_result, JSONRPCErrorResponse): return JSONResponse( handler_result.model_dump( mode='json', exclude_none=True, - ) + ), + headers=headers, ) return JSONResponse( - handler_result.root.model_dump(mode='json', exclude_none=True) + handler_result.root.model_dump(mode='json', exclude_none=True), + headers=headers, ) async def _handle_get_agent_card(self, request: Request) -> JSONResponse: diff --git a/src/a2a/server/context.py b/src/a2a/server/context.py index ce7f56bd..2b34cefe 100644 --- a/src/a2a/server/context.py +++ b/src/a2a/server/context.py @@ -21,3 +21,5 @@ class ServerCallContext(BaseModel): state: State = Field(default={}) user: User = Field(default=UnauthenticatedUser()) + requested_extensions: set[str] = Field(default_factory=set) + activated_extensions: set[str] = Field(default_factory=set) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 3693dc02..2761ed33 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -3,12 +3,14 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Sequence try: import grpc import grpc.aio + + from grpc.aio import Metadata except ImportError as e: raise ImportError( 'GrpcHandler requires grpcio and grpcio-tools to be installed. ' @@ -20,6 +22,10 @@ from a2a import types from a2a.auth.user import UnauthenticatedUser +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, + get_requested_extensions, +) from a2a.grpc import a2a_pb2 from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler @@ -42,6 +48,19 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: """Builds a ServerCallContext from a gRPC Request.""" +def _get_metadata_value( + context: grpc.aio.ServicerContext, key: str +) -> list[str]: + md = context.invocation_metadata + raw_values: list[str | bytes] = [] + if isinstance(md, Metadata): + raw_values = md.get_all(key) + elif isinstance(md, Sequence): + lower_key = key.lower() + raw_values = [e for (k, e) in md if k.lower() == lower_key] + return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values] + + class DefaultCallContextBuilder(CallContextBuilder): """A default implementation of CallContextBuilder.""" @@ -51,7 +70,13 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: state = {} with contextlib.suppress(Exception): state['grpc_context'] = context - return ServerCallContext(user=user, state=state) + return ServerCallContext( + user=user, + state=state, + requested_extensions=get_requested_extensions( + _get_metadata_value(context, HTTP_EXTENSION_HEADER) + ), + ) class GrpcHandler(a2a_grpc.A2AServiceServicer): @@ -102,6 +127,7 @@ async def SendMessage( task_or_message = await self.request_handler.on_message_send( a2a_request, server_context ) + self._set_extension_metadata(context, server_context) return proto_utils.ToProto.task_or_message(task_or_message) except ServerError as e: await self.abort_context(e, context) @@ -140,6 +166,7 @@ async def SendStreamingMessage( a2a_request, server_context ): yield proto_utils.ToProto.stream_response(event) + self._set_extension_metadata(context, server_context) except ServerError as e: await self.abort_context(e, context) return @@ -371,3 +398,16 @@ async def abort_context( grpc.StatusCode.UNKNOWN, f'Unknown error type: {error.error}', ) + + def _set_extension_metadata( + self, + context: grpc.aio.ServicerContext, + server_context: ServerCallContext, + ) -> None: + if server_context.activated_extensions: + context.set_trailing_metadata( + [ + (HTTP_EXTENSION_HEADER, e) + for e in sorted(server_context.activated_extensions) + ] + ) diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py new file mode 100644 index 00000000..137e64c9 --- /dev/null +++ b/tests/extensions/test_common.py @@ -0,0 +1,58 @@ +from a2a.extensions.common import ( + find_extension_by_uri, + get_requested_extensions, +) +from a2a.types import AgentCapabilities, AgentCard, AgentExtension + + +def test_get_requested_extensions(): + assert get_requested_extensions([]) == set() + assert get_requested_extensions(['foo']) == {'foo'} + assert get_requested_extensions(['foo', 'bar']) == {'foo', 'bar'} + assert get_requested_extensions(['foo, bar']) == {'foo', 'bar'} + assert get_requested_extensions(['foo,bar']) == {'foo', 'bar'} + assert get_requested_extensions(['foo', 'bar,baz']) == {'foo', 'bar', 'baz'} + assert get_requested_extensions(['foo,, bar', 'baz']) == { + 'foo', + 'bar', + 'baz', + } + assert get_requested_extensions([' foo , bar ', 'baz']) == { + 'foo', + 'bar', + 'baz', + } + + +def test_find_extension_by_uri(): + ext1 = AgentExtension(uri='foo', description='The Foo extension') + ext2 = AgentExtension(uri='bar', description='The Bar extension') + card = AgentCard( + name='Test Agent', + description='Test Agent Description', + version='1.0', + url='http://test.com', + skills=[], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + capabilities=AgentCapabilities(extensions=[ext1, ext2]), + ) + + assert find_extension_by_uri(card, 'foo') == ext1 + assert find_extension_by_uri(card, 'bar') == ext2 + assert find_extension_by_uri(card, 'baz') is None + + +def test_find_extension_by_uri_no_extensions(): + card = AgentCard( + name='Test Agent', + description='Test Agent Description', + version='1.0', + url='http://test.com', + skills=[], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + capabilities=AgentCapabilities(extensions=None), + ) + + assert find_extension_by_uri(card, 'foo') is None diff --git a/tests/server/agent_execution/test_context.py b/tests/server/agent_execution/test_context.py index 7b9651b0..5cecd892 100644 --- a/tests/server/agent_execution/test_context.py +++ b/tests/server/agent_execution/test_context.py @@ -5,6 +5,7 @@ import pytest from a2a.server.agent_execution import RequestContext +from a2a.server.context import ServerCallContext from a2a.types import ( Message, MessageSendParams, @@ -263,3 +264,16 @@ def test_init_with_context_id_and_existing_context_id_match( assert context.context_id == mock_task.context_id assert context.current_task == mock_task + + def test_extension_handling(self): + """Test extension handling in RequestContext.""" + call_context = ServerCallContext(requested_extensions={'foo', 'bar'}) + context = RequestContext(call_context=call_context) + + assert context.requested_extensions == {'foo', 'bar'} + + context.add_activated_extension('foo') + assert call_context.activated_extensions == {'foo'} + + context.add_activated_extension('baz') + assert call_context.activated_extensions == {'foo', 'baz'} diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py index 9e5f88ba..62dc0fe2 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/apps/jsonrpc/test_jsonrpc_app.py @@ -1,7 +1,9 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest +from starlette.testclient import TestClient + # Attempt to import StarletteBaseUser, fallback to MagicMock if not available try: @@ -9,14 +11,27 @@ except ImportError: StarletteBaseUser = MagicMock() # type: ignore +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.apps.jsonrpc.jsonrpc_app import ( - JSONRPCApplication, # Still needed for JSONRPCApplication default constructor arg + JSONRPCApplication, StarletteUserProxy, ) +from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import ( - RequestHandler, # For mock spec + RequestHandler, +) # For mock spec +from a2a.types import ( + AgentCard, + Message, + MessageSendParams, + Part, + Role, + SendMessageRequest, + SendMessageResponse, + SendMessageSuccessResponse, + TextPart, ) -from a2a.types import AgentCard # For mock spec # --- StarletteUserProxy Tests --- @@ -86,5 +101,176 @@ def some_other_method(self): ) +class TestJSONRPCExtensions: + @pytest.fixture + def mock_handler(self): + handler = AsyncMock(spec=RequestHandler) + handler.on_message_send.return_value = SendMessageResponse( + root=SendMessageSuccessResponse( + id='1', + result=Message( + message_id='test', + role=Role.agent, + parts=[Part(TextPart(text='response message'))], + ), + ) + ) + return handler + + @pytest.fixture + def test_app(self, mock_handler): + mock_agent_card = MagicMock(spec=AgentCard) + mock_agent_card.url = 'http://mockurl.com' + mock_agent_card.supports_authenticated_extended_card = False + + return A2AStarletteApplication( + agent_card=mock_agent_card, http_handler=mock_handler + ) + + @pytest.fixture + def client(self, test_app): + return TestClient(test_app.build()) + + def test_request_with_single_extension(self, client, mock_handler): + headers = {HTTP_EXTENSION_HEADER: 'foo'} + response = client.post( + '/', + headers=headers, + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + message_id='1', + role=Role.user, + parts=[Part(TextPart(text='hi'))], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert isinstance(call_context, ServerCallContext) + assert call_context.requested_extensions == {'foo'} + + def test_request_with_comma_separated_extensions( + self, client, mock_handler + ): + headers = {HTTP_EXTENSION_HEADER: 'foo, bar'} + response = client.post( + '/', + headers=headers, + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + message_id='1', + role=Role.user, + parts=[Part(TextPart(text='hi'))], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.requested_extensions == {'foo', 'bar'} + + def test_request_with_comma_separated_extensions_no_space( + self, client, mock_handler + ): + headers = [ + (HTTP_EXTENSION_HEADER, 'foo, bar'), + (HTTP_EXTENSION_HEADER, 'baz'), + ] + response = client.post( + '/', + headers=headers, + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + message_id='1', + role=Role.user, + parts=[Part(TextPart(text='hi'))], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.requested_extensions == {'foo', 'bar', 'baz'} + + def test_request_with_multiple_extension_headers( + self, client, mock_handler + ): + headers = [ + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'bar'), + ] + response = client.post( + '/', + headers=headers, + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + message_id='1', + role=Role.user, + parts=[Part(TextPart(text='hi'))], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.requested_extensions == {'foo', 'bar'} + + def test_response_with_activated_extensions(self, client, mock_handler): + def side_effect(request, context: ServerCallContext): + context.activated_extensions.add('foo') + context.activated_extensions.add('baz') + return SendMessageResponse( + root=SendMessageSuccessResponse( + id='1', + result=Message( + message_id='test', + role=Role.agent, + parts=[Part(TextPart(text='response message'))], + ), + ) + ) + + mock_handler.on_message_send.side_effect = side_effect + + response = client.post( + '/', + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + message_id='1', + role=Role.user, + parts=[Part(TextPart(text='hi'))], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + assert response.status_code == 200 + assert HTTP_EXTENSION_HEADER in response.headers + assert set(response.headers[HTTP_EXTENSION_HEADER].split(', ')) == { + 'foo', + 'baz', + } + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 852707c9..eb0a3459 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -1,10 +1,13 @@ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import grpc +import grpc.aio import pytest from a2a import types +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2 +from a2a.server.context import ServerCallContext from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.utils.errors import ServerError @@ -21,6 +24,7 @@ def mock_request_handler() -> AsyncMock: def mock_grpc_context() -> AsyncMock: context = AsyncMock(spec=grpc.aio.ServicerContext) context.abort = AsyncMock() + context.set_trailing_metadata = MagicMock() return context @@ -279,3 +283,117 @@ async def test_abort_context_error_mapping( call_args, _ = mock_grpc_context.abort.call_args assert call_args[0] == grpc_status_code assert error_message_part in call_args[1] + + +@pytest.mark.asyncio +class TestGrpcExtensions: + async def test_send_message_with_extensions( + self, + grpc_handler: GrpcHandler, + mock_request_handler: AsyncMock, + mock_grpc_context: AsyncMock, + ): + mock_grpc_context.invocation_metadata = grpc.aio.Metadata( + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'bar'), + ) + + def side_effect(request, context: ServerCallContext): + context.activated_extensions.add('foo') + context.activated_extensions.add('baz') + return types.Task( + id='task-1', + contextId='ctx-1', + status=types.TaskStatus(state=types.TaskState.completed), + ) + + mock_request_handler.on_message_send.side_effect = side_effect + + await grpc_handler.SendMessage( + a2a_pb2.SendMessageRequest(), mock_grpc_context + ) + + mock_request_handler.on_message_send.assert_awaited_once() + call_context = mock_request_handler.on_message_send.call_args[0][1] + assert isinstance(call_context, ServerCallContext) + assert call_context.requested_extensions == {'foo', 'bar'} + + mock_grpc_context.set_trailing_metadata.assert_called_once() + called_metadata = ( + mock_grpc_context.set_trailing_metadata.call_args.args[0] + ) + assert set(called_metadata) == { + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'baz'), + } + + async def test_send_message_with_comma_separated_extensions( + self, + grpc_handler: GrpcHandler, + mock_request_handler: AsyncMock, + mock_grpc_context: AsyncMock, + ): + mock_grpc_context.invocation_metadata = grpc.aio.Metadata( + (HTTP_EXTENSION_HEADER, 'foo ,, bar,'), + (HTTP_EXTENSION_HEADER, 'baz , bar'), + ) + mock_request_handler.on_message_send.return_value = types.Message( + messageId='1', + role=types.Role.agent, + parts=[types.TextPart(text='test')], + ) + + await grpc_handler.SendMessage( + a2a_pb2.SendMessageRequest(), mock_grpc_context + ) + + mock_request_handler.on_message_send.assert_awaited_once() + call_context = mock_request_handler.on_message_send.call_args[0][1] + assert isinstance(call_context, ServerCallContext) + assert call_context.requested_extensions == {'foo', 'bar', 'baz'} + + async def test_send_streaming_message_with_extensions( + self, + grpc_handler: GrpcHandler, + mock_request_handler: AsyncMock, + mock_grpc_context: AsyncMock, + ): + mock_grpc_context.invocation_metadata = grpc.aio.Metadata( + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'bar'), + ) + + async def side_effect(request, context: ServerCallContext): + context.activated_extensions.add('foo') + context.activated_extensions.add('baz') + yield types.Task( + id='task-1', + contextId='ctx-1', + status=types.TaskStatus(state=types.TaskState.working), + ) + + mock_request_handler.on_message_send_stream.side_effect = side_effect + + results = [ + result + async for result in grpc_handler.SendStreamingMessage( + a2a_pb2.SendMessageRequest(), mock_grpc_context + ) + ] + assert results + + mock_request_handler.on_message_send_stream.assert_called_once() + call_context = mock_request_handler.on_message_send_stream.call_args[0][ + 1 + ] + assert isinstance(call_context, ServerCallContext) + assert call_context.requested_extensions == {'foo', 'bar'} + + mock_grpc_context.set_trailing_metadata.assert_called_once() + called_metadata = ( + mock_grpc_context.set_trailing_metadata.call_args.args[0] + ) + assert set(called_metadata) == { + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'baz'), + }