From 56cb1f6c5d95c39e8fb49c5f6ee5318fe3572461 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 19:13:11 +0000 Subject: [PATCH 01/13] Add server support for propagating extensions, both input and output. This commit adds support for extracting extensions requested by a client, and marking an extension as activated, which causes it to be returned to clients in a header. This is purely plumbing, no particular support for actually using extensions. --- src/a2a/extensions/common.py | 13 +++++ src/a2a/server/agent_execution/context.py | 18 +++++++ src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 30 +++++++++--- src/a2a/server/context.py | 2 + .../server/request_handlers/grpc_handler.py | 47 ++++++++++++++++++- 5 files changed, 102 insertions(+), 8 deletions(-) create mode 100644 src/a2a/extensions/common.py diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py new file mode 100644 index 00000000..009cab8a --- /dev/null +++ b/src/a2a/extensions/common.py @@ -0,0 +1,13 @@ +from a2a.types import AgentCard, AgentExtension + + +HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' + + +def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: + """Find an AgentExtension in an AgentCard given a uri.""" + for ext in card.capabilities.extensions or []: + if ext.uri == uri: + return ext + + return None diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 782d488b..2b1fa948 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -143,6 +143,24 @@ def metadata(self) -> dict[str, Any]: return {} return self._params.metadata or {} + def add_activated_extension(self, uri: str) -> None: + """Add an extension to the set of activated extensions for this request. + + This causes the extension to be indicated back to the client in the + response. + """ + if self._call_context: + self._call_context.activated_extensions.add(uri) + + @property + def requested_extensions(self) -> set[str]: + """Extensions that the client requested to activate.""" + return ( + self._call_context.requested_extensions + if self._call_context + else set() + ) + def _check_or_generate_task_id(self) -> None: """Ensures a task ID is present, generating one if necessary.""" if not self._params: diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 2591bb00..77aaebf9 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -19,6 +19,7 @@ from a2a.auth.user import UnauthenticatedUser from a2a.auth.user import User as A2AUser +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.context import ServerCallContext from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler @@ -99,7 +100,15 @@ def build(self, request: Request) -> ServerCallContext: user = StarletteUserProxy(request.user) state['auth'] = request.auth state['headers'] = dict(request.headers) - return ServerCallContext(user=user, state=state) + return ServerCallContext( + user=user, + state=state, + requested_extensions={ + ext + for h in request.headers.getlist(HTTP_EXTENSION_HEADER) + for ext in h.split(', ') + }, + ) class JSONRPCApplication(ABC): @@ -281,7 +290,7 @@ async def _process_streaming_request( request_obj, context ) - return self._create_response(handler_result) + return self._create_response(context, handler_result) async def _process_non_streaming_request( self, @@ -353,10 +362,11 @@ async def _process_non_streaming_request( id=request_id, error=error ) - return self._create_response(handler_result) + return self._create_response(context, handler_result) def _create_response( self, + context: ServerCallContext, handler_result: ( AsyncGenerator[SendStreamingMessageResponse] | JSONRPCErrorResponse @@ -372,12 +382,16 @@ def _create_response( payloads. Args: + context: The ServerCallContext provided to the request handler. handler_result: The result from a request handler method. Can be an async generator for streaming or a Pydantic model for non-streaming. Returns: A Starlette JSONResponse or EventSourceResponse. """ + headers = {} + if exts := context.activated_extensions: + headers[HTTP_EXTENSION_HEADER] = ', '.join(exts) if isinstance(handler_result, AsyncGenerator): # Result is a stream of SendStreamingMessageResponse objects async def event_generator( @@ -386,17 +400,21 @@ async def event_generator( async for item in stream: yield {'data': item.root.model_dump_json(exclude_none=True)} - return EventSourceResponse(event_generator(handler_result)) + return EventSourceResponse( + event_generator(handler_result), headers=headers + ) if isinstance(handler_result, JSONRPCErrorResponse): return JSONResponse( handler_result.model_dump( mode='json', exclude_none=True, - ) + ), + headers=headers, ) return JSONResponse( - handler_result.root.model_dump(mode='json', exclude_none=True) + handler_result.root.model_dump(mode='json', exclude_none=True), + headers=headers, ) async def _handle_get_agent_card(self, request: Request) -> JSONResponse: diff --git a/src/a2a/server/context.py b/src/a2a/server/context.py index ce7f56bd..2a75c6d5 100644 --- a/src/a2a/server/context.py +++ b/src/a2a/server/context.py @@ -21,3 +21,5 @@ class ServerCallContext(BaseModel): state: State = Field(default={}) user: User = Field(default=UnauthenticatedUser()) + requested_extensions: set[str] = Field(default=set()) + activated_extensions: set[str] = Field(default=set()) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 987a0d6d..6f64813b 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Sequence try: @@ -16,10 +16,13 @@ "'pip install a2a-sdk[grpc]'" ) from e +from grpc.aio import Metadata + import a2a.grpc.a2a_pb2_grpc as a2a_grpc from a2a import types from a2a.auth.user import UnauthenticatedUser +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2 from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler @@ -42,6 +45,25 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: """Builds a ServerCallContext from a gRPC Request.""" +def _get_metadata_value( + context: grpc.aio.ServicerContext, key: str +) -> list[str]: + md = context.invocation_metadata + vs = [] + if isinstance(md, Metadata): + vs = [ + e if isinstance(e, str) else e.decode('utf-8') + for e in md.get_all(key) + ] + elif isinstance(md, Sequence): + vs = [ + e if isinstance(e, str) else e.decode('utf-8') + for (k, e) in md + if k == key.lower() + ] + return vs + + class DefaultCallContextBuilder(CallContextBuilder): """A default implementation of CallContextBuilder.""" @@ -51,7 +73,13 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: state = {} with contextlib.suppress(Exception): state['grpc_context'] = context - return ServerCallContext(user=user, state=state) + return ServerCallContext( + user=user, + state=state, + requested_extensions=set( + _get_metadata_value(context, HTTP_EXTENSION_HEADER) + ), + ) class GrpcHandler(a2a_grpc.A2AServiceServicer): @@ -102,6 +130,7 @@ async def SendMessage( task_or_message = await self.request_handler.on_message_send( a2a_request, server_context ) + self._set_extension_metadata(context, server_context) return proto_utils.ToProto.task_or_message(task_or_message) except ServerError as e: await self.abort_context(e, context) @@ -140,6 +169,7 @@ async def SendStreamingMessage( a2a_request, server_context ): yield proto_utils.ToProto.stream_response(event) + self._set_extension_metadata(context, server_context) except ServerError as e: await self.abort_context(e, context) return @@ -371,3 +401,16 @@ async def abort_context( grpc.StatusCode.UNKNOWN, f'Unknown error type: {error.error}', ) + + def _set_extension_metadata( + self, + context: grpc.aio.ServicerContext, + server_context: ServerCallContext, + ) -> None: + if server_context.activated_extensions: + context.set_trailing_metadata( + [ + (HTTP_EXTENSION_HEADER, e) + for e in server_context.activated_extensions + ] + ) From a941547a7d3a8d0903cbed5885a7d3b25861c091 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 19:15:51 +0000 Subject: [PATCH 02/13] Add tests related to server side extensions. --- tests/extensions/test_common.py | 38 ++++ tests/server/agent_execution/test_context.py | 14 ++ tests/server/apps/jsonrpc/test_jsonrpc_app.py | 170 +++++++++++++++++- .../request_handlers/test_grpc_handler.py | 92 +++++++++- 4 files changed, 306 insertions(+), 8 deletions(-) create mode 100644 tests/extensions/test_common.py diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py new file mode 100644 index 00000000..4efef96d --- /dev/null +++ b/tests/extensions/test_common.py @@ -0,0 +1,38 @@ +from a2a.extensions.common import find_extension_by_uri +from a2a.types import AgentCard, AgentExtension, AgentCapabilities + + +def test_find_extension_by_uri(): + ext1 = AgentExtension(uri='foo', name='Foo', description='The Foo extension') + ext2 = AgentExtension(uri='bar', name='Bar', description='The Bar extension') + card = AgentCard( + agent_id='test-agent', + name='Test Agent', + description='Test Agent Description', + version='1.0', + url='http://test.com', + skills=[], + defaultInputModes=['text/plain'], + defaultOutputModes=['text/plain'], + capabilities=AgentCapabilities(extensions=[ext1, ext2]), + ) + + assert find_extension_by_uri(card, 'foo') == ext1 + assert find_extension_by_uri(card, 'bar') == ext2 + assert find_extension_by_uri(card, 'baz') is None + + +def test_find_extension_by_uri_no_extensions(): + card = AgentCard( + agent_id='test-agent', + name='Test Agent', + description='Test Agent Description', + version='1.0', + url='http://test.com', + skills=[], + defaultInputModes=['text/plain'], + defaultOutputModes=['text/plain'], + capabilities=AgentCapabilities(extensions=None), + ) + + assert find_extension_by_uri(card, 'foo') is None diff --git a/tests/server/agent_execution/test_context.py b/tests/server/agent_execution/test_context.py index aa72b5a6..fdb8aa00 100644 --- a/tests/server/agent_execution/test_context.py +++ b/tests/server/agent_execution/test_context.py @@ -5,6 +5,7 @@ import pytest from a2a.server.agent_execution import RequestContext +from a2a.server.context import ServerCallContext from a2a.types import ( Message, MessageSendParams, @@ -262,3 +263,16 @@ def test_init_with_context_id_and_existing_context_id_match( assert context.context_id == mock_task.contextId assert context.current_task == mock_task + + def test_extension_handling(self): + """Test extension handling in RequestContext.""" + call_context = ServerCallContext(requested_extensions={'foo', 'bar'}) + context = RequestContext(call_context=call_context) + + assert context.requested_extensions == {'foo', 'bar'} + + context.add_activated_extension('foo') + assert call_context.activated_extensions == {'foo'} + + context.add_activated_extension('baz') + assert call_context.activated_extensions == {'foo', 'baz'} diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py index 6670e40b..09a57060 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/apps/jsonrpc/test_jsonrpc_app.py @@ -1,7 +1,8 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest - +from starlette.applications import Starlette +from starlette.testclient import TestClient # Attempt to import StarletteBaseUser, fallback to MagicMock if not available try: @@ -9,15 +10,27 @@ except ImportError: StarletteBaseUser = MagicMock() # type: ignore +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.apps.jsonrpc.jsonrpc_app import ( - JSONRPCApplication, # Still needed for JSONRPCApplication default constructor arg + JSONRPCApplication, StarletteUserProxy, ) +from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import ( - RequestHandler, # For mock spec + RequestHandler, +) # For mock spec +from a2a.types import ( + AgentCapabilities, + AgentCard, + Message, + MessageSendParams, + Role, + SendMessageRequest, + SendMessageResponse, + SendMessageSuccessResponse, + TextPart, ) -from a2a.types import AgentCard # For mock spec - # --- StarletteUserProxy Tests --- @@ -69,6 +82,7 @@ def test_jsonrpc_app_build_method_abstract_raises_typeerror( mock_agent_card.url = 'http://mockurl.com' # Ensure 'supportsAuthenticatedExtendedCard' attribute exists mock_agent_card.supportsAuthenticatedExtendedCard = False + mock_agent_card.capabilities = AgentCapabilities(streaming=True) # This will fail at definition time if an abstract method is not implemented with pytest.raises( @@ -86,5 +100,149 @@ def some_other_method(self): ) +class TestJSONRPCExtensions: + @pytest.fixture + def mock_handler(self): + handler = AsyncMock(spec=RequestHandler) + handler.on_message_send.return_value = SendMessageResponse( + root=SendMessageSuccessResponse( + id='1', + result=Message( + messageId='test', + role=Role.agent, + parts=[TextPart(text='response message')], + ), + ) + ) + return handler + + @pytest.fixture + def test_app(self, mock_handler): + mock_agent_card = MagicMock(spec=AgentCard) + mock_agent_card.url = 'http://mockurl.com' + mock_agent_card.supportsAuthenticatedExtendedCard = False + + return A2AStarletteApplication( + agent_card=mock_agent_card, http_handler=mock_handler + ) + + @pytest.fixture + def client(self, test_app): + return TestClient(test_app.build()) + + def test_request_with_single_extension(self, client, mock_handler): + headers = {HTTP_EXTENSION_HEADER: 'foo'} + response = client.post( + '/', + headers=headers, + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + messageId='1', + role=Role.user, + parts=[TextPart(text='hi')], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert isinstance(call_context, ServerCallContext) + assert call_context.requested_extensions == {'foo'} + + def test_request_with_comma_separated_extensions( + self, client, mock_handler + ): + headers = {HTTP_EXTENSION_HEADER: 'foo, bar'} + response = client.post( + '/', + headers=headers, + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + messageId='1', + role=Role.user, + parts=[TextPart(text='hi')], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.requested_extensions == {'foo', 'bar'} + + def test_request_with_multiple_extension_headers( + self, client, mock_handler + ): + headers = [ + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'bar'), + ] + response = client.post( + '/', + headers=headers, + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + messageId='1', + role=Role.user, + parts=[TextPart(text='hi')], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.requested_extensions == {'foo', 'bar'} + + def test_response_with_activated_extensions(self, client, mock_handler): + def side_effect(request, context: ServerCallContext): + context.activated_extensions.add('foo') + context.activated_extensions.add('baz') + return SendMessageResponse( + root=SendMessageSuccessResponse( + id='1', + result=Message( + messageId='test', + role=Role.agent, + parts=[TextPart(text='response message')], + ), + ) + ) + + mock_handler.on_message_send.side_effect = side_effect + + response = client.post( + '/', + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + messageId='1', + role=Role.user, + parts=[TextPart(text='hi')], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + assert response.status_code == 200 + assert HTTP_EXTENSION_HEADER in response.headers + assert set(response.headers[HTTP_EXTENSION_HEADER].split(', ')) == { + 'foo', + 'baz', + } + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index e1b8b940..523f8ba4 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -1,14 +1,15 @@ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock, patch import grpc import pytest from a2a import types +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2 +from a2a.server.context import ServerCallContext from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.utils.errors import ServerError - # --- Fixtures --- @@ -21,6 +22,8 @@ def mock_request_handler() -> AsyncMock: def mock_grpc_context() -> AsyncMock: context = AsyncMock(spec=grpc.aio.ServicerContext) context.abort = AsyncMock() + context.invocation_metadata = MagicMock() + context.set_trailing_metadata = MagicMock() return context @@ -279,3 +282,88 @@ async def test_abort_context_error_mapping( call_args, _ = mock_grpc_context.abort.call_args assert call_args[0] == grpc_status_code assert error_message_part in call_args[1] + + +@pytest.mark.asyncio +class TestGrpcExtensions: + @patch( + 'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build' + ) + async def test_send_message_with_extensions( + self, + mock_build, + grpc_handler: GrpcHandler, + mock_request_handler: AsyncMock, + mock_grpc_context: AsyncMock, + ): + mock_build.return_value = ServerCallContext( + requested_extensions={'foo', 'bar'} + ) + + def side_effect(request, context: ServerCallContext): + context.activated_extensions.add('foo') + context.activated_extensions.add('baz') + return types.Task( + id='task-1', + contextId='ctx-1', + status=types.TaskStatus(state=types.TaskState.completed), + ) + + mock_request_handler.on_message_send.side_effect = side_effect + + await grpc_handler.SendMessage( + a2a_pb2.SendMessageRequest(), mock_grpc_context + ) + + mock_request_handler.on_message_send.assert_awaited_once() + call_context = mock_request_handler.on_message_send.call_args[0][1] + assert isinstance(call_context, ServerCallContext) + assert call_context.requested_extensions == {'foo', 'bar'} + + mock_grpc_context.set_trailing_metadata.assert_called_once_with( + [(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')] + ) + + @patch( + 'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build' + ) + async def test_send_streaming_message_with_extensions( + self, + mock_build, + grpc_handler: GrpcHandler, + mock_request_handler: AsyncMock, + mock_grpc_context: AsyncMock, + ): + mock_build.return_value = ServerCallContext( + requested_extensions={'foo', 'bar'} + ) + + async def side_effect(request, context: ServerCallContext): + context.activated_extensions.add('foo') + context.activated_extensions.add('baz') + yield types.Task( + id='task-1', + contextId='ctx-1', + status=types.TaskStatus(state=types.TaskState.working), + ) + + mock_request_handler.on_message_send_stream.side_effect = side_effect + + results = [ + result + async for result in grpc_handler.SendStreamingMessage( + a2a_pb2.SendMessageRequest(), mock_grpc_context + ) + ] + assert results + + mock_request_handler.on_message_send_stream.assert_called_once() + call_context = mock_request_handler.on_message_send_stream.call_args[0][ + 1 + ] + assert isinstance(call_context, ServerCallContext) + assert call_context.requested_extensions == {'foo', 'bar'} + + mock_grpc_context.set_trailing_metadata.assert_called_once_with( + [(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')] + ) From b5007421fd14fa2b94c26e7f7d4ef9e6369f6483 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 19:29:41 +0000 Subject: [PATCH 03/13] Ruff format --- tests/extensions/test_common.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index 4efef96d..a769f647 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -3,8 +3,12 @@ def test_find_extension_by_uri(): - ext1 = AgentExtension(uri='foo', name='Foo', description='The Foo extension') - ext2 = AgentExtension(uri='bar', name='Bar', description='The Bar extension') + ext1 = AgentExtension( + uri='foo', name='Foo', description='The Foo extension' + ) + ext2 = AgentExtension( + uri='bar', name='Bar', description='The Bar extension' + ) card = AgentCard( agent_id='test-agent', name='Test Agent', From 091499cac12a0d1b22c604937946bd5aa9cce5a6 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 19:32:55 +0000 Subject: [PATCH 04/13] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/server/context.py | 4 ++-- tests/server/request_handlers/test_grpc_handler.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/a2a/server/context.py b/src/a2a/server/context.py index 2a75c6d5..2b34cefe 100644 --- a/src/a2a/server/context.py +++ b/src/a2a/server/context.py @@ -21,5 +21,5 @@ class ServerCallContext(BaseModel): state: State = Field(default={}) user: User = Field(default=UnauthenticatedUser()) - requested_extensions: set[str] = Field(default=set()) - activated_extensions: set[str] = Field(default=set()) + requested_extensions: set[str] = Field(default_factory=set) + activated_extensions: set[str] = Field(default_factory=set) diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 523f8ba4..74db9748 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -320,9 +320,9 @@ def side_effect(request, context: ServerCallContext): assert isinstance(call_context, ServerCallContext) assert call_context.requested_extensions == {'foo', 'bar'} - mock_grpc_context.set_trailing_metadata.assert_called_once_with( - [(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')] - ) + mock_grpc_context.set_trailing_metadata.assert_called_once() + called_metadata = mock_grpc_context.set_trailing_metadata.call_args.args[0] + assert set(called_metadata) == {(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')} @patch( 'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build' @@ -364,6 +364,6 @@ async def side_effect(request, context: ServerCallContext): assert isinstance(call_context, ServerCallContext) assert call_context.requested_extensions == {'foo', 'bar'} - mock_grpc_context.set_trailing_metadata.assert_called_once_with( - [(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')] - ) + mock_grpc_context.set_trailing_metadata.assert_called_once() + called_metadata = mock_grpc_context.set_trailing_metadata.call_args.args[0] + assert set(called_metadata) == {(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')} From 75849c05feb81076ca110cc7b732d6b053a9bbff Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 19:35:09 +0000 Subject: [PATCH 05/13] Split headers by single space, handle trimming --- src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 77aaebf9..094126a6 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -104,9 +104,10 @@ def build(self, request: Request) -> ServerCallContext: user=user, state=state, requested_extensions={ - ext + stripped for h in request.headers.getlist(HTTP_EXTENSION_HEADER) - for ext in h.split(', ') + for ext in h.split(',') + if (stripped := ext.strip()) }, ) From b141ecfd405d0db33fe8700282c18fa7ce2e5b39 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 20:30:55 +0000 Subject: [PATCH 06/13] Refactor to use common header value splitter --- src/a2a/extensions/common.py | 14 +++++ src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 14 ++--- .../server/request_handlers/grpc_handler.py | 7 ++- tests/extensions/test_common.py | 36 ++++++++--- tests/server/apps/jsonrpc/test_jsonrpc_app.py | 27 ++++++++ .../request_handlers/test_grpc_handler.py | 63 ++++++++++++++----- 6 files changed, 125 insertions(+), 36 deletions(-) diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 009cab8a..2f752caa 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -4,6 +4,20 @@ HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' +def get_requested_extensions(values: list[str]) -> set[str]: + """Get the set of requested extensions from an input list. + + This handles the list containing potentially comma-separated values, as + occurs when using a list in an HTTP header. + """ + return { + stripped + for v in values + for ext in v.split(',') + if (stripped := ext.strip()) + } + + def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: """Find an AgentExtension in an AgentCard given a uri.""" for ext in card.capabilities.extensions or []: diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 094126a6..2846a15c 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -19,7 +19,10 @@ from a2a.auth.user import UnauthenticatedUser from a2a.auth.user import User as A2AUser -from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, + get_requested_extensions, +) from a2a.server.context import ServerCallContext from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler @@ -103,12 +106,9 @@ def build(self, request: Request) -> ServerCallContext: return ServerCallContext( user=user, state=state, - requested_extensions={ - stripped - for h in request.headers.getlist(HTTP_EXTENSION_HEADER) - for ext in h.split(',') - if (stripped := ext.strip()) - }, + requested_extensions=get_requested_extensions( + request.headers.getlist(HTTP_EXTENSION_HEADER) + ), ) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 6f64813b..2eb0f4c9 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -22,7 +22,10 @@ from a2a import types from a2a.auth.user import UnauthenticatedUser -from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, + get_requested_extensions, +) from a2a.grpc import a2a_pb2 from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler @@ -76,7 +79,7 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: return ServerCallContext( user=user, state=state, - requested_extensions=set( + requested_extensions=get_requested_extensions( _get_metadata_value(context, HTTP_EXTENSION_HEADER) ), ) diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index a769f647..967689e4 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -1,16 +1,33 @@ -from a2a.extensions.common import find_extension_by_uri -from a2a.types import AgentCard, AgentExtension, AgentCapabilities +from a2a.extensions.common import ( + find_extension_by_uri, + get_requested_extensions, +) +from a2a.types import AgentCapabilities, AgentCard, AgentExtension + + +def test_get_requested_extensions(): + assert get_requested_extensions([]) == set() + assert get_requested_extensions(['foo']) == {'foo'} + assert get_requested_extensions(['foo', 'bar']) == {'foo', 'bar'} + assert get_requested_extensions(['foo, bar']) == {'foo', 'bar'} + assert get_requested_extensions(['foo,bar']) == {'foo', 'bar'} + assert get_requested_extensions(['foo', 'bar,baz']) == {'foo', 'bar', 'baz'} + assert get_requested_extensions(['foo,, bar', 'baz']) == { + 'foo', + 'bar', + 'baz', + } + assert get_requested_extensions([' foo , bar ', 'baz']) == { + 'foo', + 'bar', + 'baz', + } def test_find_extension_by_uri(): - ext1 = AgentExtension( - uri='foo', name='Foo', description='The Foo extension' - ) - ext2 = AgentExtension( - uri='bar', name='Bar', description='The Bar extension' - ) + ext1 = AgentExtension(uri='foo', description='The Foo extension') + ext2 = AgentExtension(uri='bar', description='The Bar extension') card = AgentCard( - agent_id='test-agent', name='Test Agent', description='Test Agent Description', version='1.0', @@ -28,7 +45,6 @@ def test_find_extension_by_uri(): def test_find_extension_by_uri_no_extensions(): card = AgentCard( - agent_id='test-agent', name='Test Agent', description='Test Agent Description', version='1.0', diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py index 09a57060..0443ecec 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/apps/jsonrpc/test_jsonrpc_app.py @@ -177,6 +177,33 @@ def test_request_with_comma_separated_extensions( call_context = mock_handler.on_message_send.call_args[0][1] assert call_context.requested_extensions == {'foo', 'bar'} + def test_request_with_comma_separated_extensions_no_space( + self, client, mock_handler + ): + headers = [ + (HTTP_EXTENSION_HEADER, 'foo, bar'), + (HTTP_EXTENSION_HEADER, 'baz'), + ] + response = client.post( + '/', + headers=headers, + json=SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + messageId='1', + role=Role.user, + parts=[TextPart(text='hi')], + ) + ), + ).model_dump(), + ) + response.raise_for_status() + + mock_handler.on_message_send.assert_called_once() + call_context = mock_handler.on_message_send.call_args[0][1] + assert call_context.requested_extensions == {'foo', 'bar', 'baz'} + def test_request_with_multiple_extension_headers( self, client, mock_handler ): diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 74db9748..f38b7a65 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -1,6 +1,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import grpc +import grpc.aio import pytest from a2a import types @@ -22,7 +23,6 @@ def mock_request_handler() -> AsyncMock: def mock_grpc_context() -> AsyncMock: context = AsyncMock(spec=grpc.aio.ServicerContext) context.abort = AsyncMock() - context.invocation_metadata = MagicMock() context.set_trailing_metadata = MagicMock() return context @@ -286,18 +286,15 @@ async def test_abort_context_error_mapping( @pytest.mark.asyncio class TestGrpcExtensions: - @patch( - 'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build' - ) async def test_send_message_with_extensions( self, - mock_build, grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, ): - mock_build.return_value = ServerCallContext( - requested_extensions={'foo', 'bar'} + mock_grpc_context.invocation_metadata = grpc.aio.Metadata( + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'bar'), ) def side_effect(request, context: ServerCallContext): @@ -321,21 +318,48 @@ def side_effect(request, context: ServerCallContext): assert call_context.requested_extensions == {'foo', 'bar'} mock_grpc_context.set_trailing_metadata.assert_called_once() - called_metadata = mock_grpc_context.set_trailing_metadata.call_args.args[0] - assert set(called_metadata) == {(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')} + called_metadata = ( + mock_grpc_context.set_trailing_metadata.call_args.args[0] + ) + assert set(called_metadata) == { + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'baz'), + } + + async def test_send_message_with_comma_separated_extensions( + self, + grpc_handler: GrpcHandler, + mock_request_handler: AsyncMock, + mock_grpc_context: AsyncMock, + ): + mock_grpc_context.invocation_metadata = grpc.aio.Metadata( + (HTTP_EXTENSION_HEADER, 'foo ,, bar,'), + (HTTP_EXTENSION_HEADER, 'baz , bar'), + ) + mock_request_handler.on_message_send.return_value = types.Message( + messageId='1', + role=types.Role.agent, + parts=[types.TextPart(text='test')], + ) + + await grpc_handler.SendMessage( + a2a_pb2.SendMessageRequest(), mock_grpc_context + ) + + mock_request_handler.on_message_send.assert_awaited_once() + call_context = mock_request_handler.on_message_send.call_args[0][1] + assert isinstance(call_context, ServerCallContext) + assert call_context.requested_extensions == {'foo', 'bar', 'baz'} - @patch( - 'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build' - ) async def test_send_streaming_message_with_extensions( self, - mock_build, grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, ): - mock_build.return_value = ServerCallContext( - requested_extensions={'foo', 'bar'} + mock_grpc_context.invocation_metadata = grpc.aio.Metadata( + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'bar'), ) async def side_effect(request, context: ServerCallContext): @@ -365,5 +389,10 @@ async def side_effect(request, context: ServerCallContext): assert call_context.requested_extensions == {'foo', 'bar'} mock_grpc_context.set_trailing_metadata.assert_called_once() - called_metadata = mock_grpc_context.set_trailing_metadata.call_args.args[0] - assert set(called_metadata) == {(HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'baz')} + called_metadata = ( + mock_grpc_context.set_trailing_metadata.call_args.args[0] + ) + assert set(called_metadata) == { + (HTTP_EXTENSION_HEADER, 'foo'), + (HTTP_EXTENSION_HEADER, 'baz'), + } From c27d7f56ad18fc452deb63096d089f8fd55e2f96 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 20:37:12 +0000 Subject: [PATCH 07/13] Fix bad merge --- tests/server/apps/jsonrpc/test_jsonrpc_app.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py index f8fff1e9..227a2648 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/apps/jsonrpc/test_jsonrpc_app.py @@ -82,7 +82,6 @@ def test_jsonrpc_app_build_method_abstract_raises_typeerror( mock_agent_card.url = 'http://mockurl.com' # Ensure 'supportsAuthenticatedExtendedCard' attribute exists mock_agent_card.supports_authenticated_extended_card = False - mock_agent_card.capabilities = AgentCapabilities(streaming=True) # This will fail at definition time if an abstract method is not implemented with pytest.raises( @@ -120,7 +119,7 @@ def mock_handler(self): def test_app(self, mock_handler): mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://mockurl.com' - mock_agent_card.supportsAuthenticatedExtendedCard = False + mock_agent_card.supports_authenticated_extended_card = False return A2AStarletteApplication( agent_card=mock_agent_card, http_handler=mock_handler From d279515038406f30114ce9a2575ec5c86bed878e Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 20:43:48 +0000 Subject: [PATCH 08/13] Refactors from review: sort extensions in response, small method refactor --- src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 2 +- .../server/request_handlers/grpc_handler.py | 18 ++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index e4fca1a7..e149b4d1 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -392,7 +392,7 @@ def _create_response( """ headers = {} if exts := context.activated_extensions: - headers[HTTP_EXTENSION_HEADER] = ', '.join(exts) + headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts)) if isinstance(handler_result, AsyncGenerator): # Result is a stream of SendStreamingMessageResponse objects async def event_generator( diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index b3349f7c..e1c9c8a2 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -52,19 +52,13 @@ def _get_metadata_value( context: grpc.aio.ServicerContext, key: str ) -> list[str]: md = context.invocation_metadata - vs = [] + raw_values: list[str | bytes] = [] if isinstance(md, Metadata): - vs = [ - e if isinstance(e, str) else e.decode('utf-8') - for e in md.get_all(key) - ] + raw_values = md.get_all(key) elif isinstance(md, Sequence): - vs = [ - e if isinstance(e, str) else e.decode('utf-8') - for (k, e) in md - if k == key.lower() - ] - return vs + lower_key = key.lower() + raw_values = [e for (k, e) in md if k == lower_key] + return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values] class DefaultCallContextBuilder(CallContextBuilder): @@ -414,6 +408,6 @@ def _set_extension_metadata( context.set_trailing_metadata( [ (HTTP_EXTENSION_HEADER, e) - for e in server_context.activated_extensions + for e in sorted(server_context.activated_extensions) ] ) From edf38c1504a68fc6d25e3118976321969baf21e2 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 20:56:37 +0000 Subject: [PATCH 09/13] Case insensitive comparison of grpc metadata keys --- src/a2a/server/request_handlers/grpc_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index e1c9c8a2..409dacdf 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -57,7 +57,7 @@ def _get_metadata_value( raw_values = md.get_all(key) elif isinstance(md, Sequence): lower_key = key.lower() - raw_values = [e for (k, e) in md if k == lower_key] + raw_values = [e for (k, e) in md if k.lower() == lower_key] return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values] From 93675e17f03c71b1951a32dcedbe3b898ca86ef9 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 21:29:38 +0000 Subject: [PATCH 10/13] Rename fields to camel_case, fix grpc import --- .../server/request_handlers/grpc_handler.py | 7 +---- tests/extensions/test_common.py | 8 ++--- tests/server/apps/jsonrpc/test_jsonrpc_app.py | 31 +++++++++---------- 3 files changed, 20 insertions(+), 26 deletions(-) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 409dacdf..aaaef8ef 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -1,14 +1,13 @@ # ruff: noqa: N802 import contextlib import logging - from abc import ABC, abstractmethod from collections.abc import AsyncIterable, Sequence - try: import grpc import grpc.aio + from grpc.aio import Metadata except ImportError as e: raise ImportError( 'GrpcHandler requires grpcio and grpcio-tools to be installed. ' @@ -16,10 +15,7 @@ "'pip install a2a-sdk[grpc]'" ) from e -from grpc.aio import Metadata - import a2a.grpc.a2a_pb2_grpc as a2a_grpc - from a2a import types from a2a.auth.user import UnauthenticatedUser from a2a.extensions.common import ( @@ -34,7 +30,6 @@ from a2a.utils.errors import ServerError from a2a.utils.helpers import validate, validate_async_generator - logger = logging.getLogger(__name__) # For now we use a trivial wrapper on the grpc context object diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index 967689e4..137e64c9 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -33,8 +33,8 @@ def test_find_extension_by_uri(): version='1.0', url='http://test.com', skills=[], - defaultInputModes=['text/plain'], - defaultOutputModes=['text/plain'], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], capabilities=AgentCapabilities(extensions=[ext1, ext2]), ) @@ -50,8 +50,8 @@ def test_find_extension_by_uri_no_extensions(): version='1.0', url='http://test.com', skills=[], - defaultInputModes=['text/plain'], - defaultOutputModes=['text/plain'], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], capabilities=AgentCapabilities(extensions=None), ) diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py index 227a2648..7c6bf872 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/apps/jsonrpc/test_jsonrpc_app.py @@ -1,7 +1,6 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from starlette.applications import Starlette from starlette.testclient import TestClient # Attempt to import StarletteBaseUser, fallback to MagicMock if not available @@ -21,10 +20,10 @@ RequestHandler, ) # For mock spec from a2a.types import ( - AgentCapabilities, AgentCard, Message, MessageSendParams, + Part, Role, SendMessageRequest, SendMessageResponse, @@ -107,9 +106,9 @@ def mock_handler(self): root=SendMessageSuccessResponse( id='1', result=Message( - messageId='test', + message_id='test', role=Role.agent, - parts=[TextPart(text='response message')], + parts=[Part(TextPart(text='response message'))], ), ) ) @@ -138,9 +137,9 @@ def test_request_with_single_extension(self, client, mock_handler): id='1', params=MessageSendParams( message=Message( - messageId='1', + message_id='1', role=Role.user, - parts=[TextPart(text='hi')], + parts=[Part(TextPart(text='hi'))], ) ), ).model_dump(), @@ -163,9 +162,9 @@ def test_request_with_comma_separated_extensions( id='1', params=MessageSendParams( message=Message( - messageId='1', + message_id='1', role=Role.user, - parts=[TextPart(text='hi')], + parts=[Part(TextPart(text='hi'))], ) ), ).model_dump(), @@ -190,9 +189,9 @@ def test_request_with_comma_separated_extensions_no_space( id='1', params=MessageSendParams( message=Message( - messageId='1', + message_id='1', role=Role.user, - parts=[TextPart(text='hi')], + parts=[Part(TextPart(text='hi'))], ) ), ).model_dump(), @@ -217,9 +216,9 @@ def test_request_with_multiple_extension_headers( id='1', params=MessageSendParams( message=Message( - messageId='1', + message_id='1', role=Role.user, - parts=[TextPart(text='hi')], + parts=[Part(TextPart(text='hi'))], ) ), ).model_dump(), @@ -238,9 +237,9 @@ def side_effect(request, context: ServerCallContext): root=SendMessageSuccessResponse( id='1', result=Message( - messageId='test', + message_id='test', role=Role.agent, - parts=[TextPart(text='response message')], + parts=[Part(TextPart(text='response message'))], ), ) ) @@ -253,9 +252,9 @@ def side_effect(request, context: ServerCallContext): id='1', params=MessageSendParams( message=Message( - messageId='1', + message_id='1', role=Role.user, - parts=[TextPart(text='hi')], + parts=[Part(TextPart(text='hi'))], ) ), ).model_dump(), From ea78b6b89c88aabb18071f2628119fd3b35c2409 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 21 Jul 2025 21:30:59 +0000 Subject: [PATCH 11/13] uv run ruff check --fix --- src/a2a/server/request_handlers/grpc_handler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index aaaef8ef..2761ed33 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -1,12 +1,15 @@ # ruff: noqa: N802 import contextlib import logging + from abc import ABC, abstractmethod from collections.abc import AsyncIterable, Sequence + try: import grpc import grpc.aio + from grpc.aio import Metadata except ImportError as e: raise ImportError( @@ -16,6 +19,7 @@ ) from e import a2a.grpc.a2a_pb2_grpc as a2a_grpc + from a2a import types from a2a.auth.user import UnauthenticatedUser from a2a.extensions.common import ( @@ -30,6 +34,7 @@ from a2a.utils.errors import ServerError from a2a.utils.helpers import validate, validate_async_generator + logger = logging.getLogger(__name__) # For now we use a trivial wrapper on the grpc context object From 772656c2d4e0904da9995c6eeda3048aa1b2e859 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Tue, 22 Jul 2025 15:21:56 +0100 Subject: [PATCH 12/13] Formatting --- tests/server/apps/jsonrpc/test_jsonrpc_app.py | 3 +++ tests/server/request_handlers/test_grpc_handler.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py index 7c6bf872..62dc0fe2 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/apps/jsonrpc/test_jsonrpc_app.py @@ -1,8 +1,10 @@ from unittest.mock import AsyncMock, MagicMock import pytest + from starlette.testclient import TestClient + # Attempt to import StarletteBaseUser, fallback to MagicMock if not available try: from starlette.authentication import BaseUser as StarletteBaseUser @@ -31,6 +33,7 @@ TextPart, ) + # --- StarletteUserProxy Tests --- diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 9ed32a68..eb0a3459 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import grpc import grpc.aio @@ -11,6 +11,7 @@ from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.utils.errors import ServerError + # --- Fixtures --- From 5bfa79055f7d40f2ee1888200a09007095d1f9f4 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Tue, 22 Jul 2025 15:32:01 +0100 Subject: [PATCH 13/13] ci: Change no_implicit_optional execution in format.sh --- scripts/format.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/format.sh b/scripts/format.sh index 2b91144b..efe187a9 100755 --- a/scripts/format.sh +++ b/scripts/format.sh @@ -32,7 +32,7 @@ run_formatter() { echo "$CHANGED_FILES" | xargs -r "$@" } -run_formatter no_implicit_optional --use-union-or +no-implicit-optional --use-union-or . run_formatter pyupgrade --exit-zero-even-if-changed --py310-plus run_formatter autoflake -i -r --remove-all-unused-imports run_formatter ruff check --fix-only