diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py index 579ff79c..c18cde3d 100644 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ b/src/a2a/server/apps/jsonrpc/fastapi_app.py @@ -7,11 +7,9 @@ from fastapi import FastAPI from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, JSONRPCApplication, ) -from a2a.server.request_handlers.jsonrpc_handler import RequestHandler -from a2a.types import A2ARequest, AgentCard +from a2a.types import A2ARequest from a2a.utils.constants import ( AGENT_CARD_WELL_KNOWN_PATH, DEFAULT_RPC_URL, @@ -31,32 +29,6 @@ class A2AFastAPIApplication(JSONRPCApplication): (SSE). """ - def __init__( - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - ) -> None: - """Initializes the A2AStarletteApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - """ - super().__init__( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - ) - def add_routes_to_app( self, app: FastAPI, @@ -90,13 +62,13 @@ def add_routes_to_app( )(self._handle_requests) app.get(agent_card_url)(self._handle_get_agent_card) - # add deprecated path only if the agent_card_url uses default well-known path if agent_card_url == AGENT_CARD_WELL_KNOWN_PATH: - app.get(PREV_AGENT_CARD_WELL_KNOWN_PATH, include_in_schema=False)( - self.handle_deprecated_agent_card_path + # For backward compatibility, serve the agent card at the deprecated path as well. + # TODO: remove in a future release + app.get(PREV_AGENT_CARD_WELL_KNOWN_PATH)( + self._handle_get_agent_card ) - # TODO: deprecated endpoint to be removed in a future release if self.agent_card.supports_authenticated_extended_card: app.get(extended_agent_card_url)( self._handle_get_authenticated_extended_agent_card diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 6f374848..95aa8079 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -4,7 +4,7 @@ import traceback from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from typing import Any from fastapi import FastAPI @@ -123,12 +123,17 @@ class JSONRPCApplication(ABC): (SSE). """ - def __init__( + def __init__( # noqa: PLR0913 self, agent_card: AgentCard, http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], AgentCard] | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], AgentCard + ] + | None = None, ) -> None: """Initializes the A2AStarletteApplication. @@ -141,17 +146,26 @@ def __init__( context_builder: The CallContextBuilder used to construct the ServerCallContext passed to the http_handler. If None, no ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. """ self.agent_card = agent_card self.extended_agent_card = extended_agent_card + self.card_modifier = card_modifier + self.extended_card_modifier = extended_card_modifier self.handler = JSONRPCHandler( agent_card=agent_card, request_handler=http_handler, extended_agent_card=extended_agent_card, + extended_card_modifier=extended_card_modifier, ) if ( self.agent_card.supports_authenticated_extended_card and self.extended_agent_card is None + and self.extended_card_modifier is None ): logger.error( 'AgentCard.supports_authenticated_extended_card is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.' @@ -448,24 +462,23 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse: Returns: A JSONResponse containing the agent card data. """ - # The public agent card is a direct serialization of the agent_card - # provided at initialization. + if request.url.path == PREV_AGENT_CARD_WELL_KNOWN_PATH: + logger.warning( + f"Deprecated agent card endpoint '{PREV_AGENT_CARD_WELL_KNOWN_PATH}' accessed. " + f"Please use '{AGENT_CARD_WELL_KNOWN_PATH}' instead. This endpoint will be removed in a future version." + ) + + card_to_serve = self.agent_card + if self.card_modifier: + card_to_serve = self.card_modifier(card_to_serve) + return JSONResponse( - self.agent_card.model_dump( + card_to_serve.model_dump( exclude_none=True, by_alias=True, ) ) - async def handle_deprecated_agent_card_path( - self, request: Request - ) -> JSONResponse: - """Handles GET requests for the deprecated agent card endpoint.""" - logger.warning( - f"Deprecated agent card endpoint '{PREV_AGENT_CARD_WELL_KNOWN_PATH}' accessed. Please use '{AGENT_CARD_WELL_KNOWN_PATH}' instead. This endpoint will be removed in a future version." - ) - return await self._handle_get_agent_card(request) - async def _handle_get_authenticated_extended_agent_card( self, request: Request ) -> JSONResponse: @@ -480,17 +493,24 @@ async def _handle_get_authenticated_extended_agent_card( status_code=404, ) - # If an explicit extended_agent_card is provided, serve that. - if self.extended_agent_card: + card_to_serve = self.extended_agent_card + + if self.extended_card_modifier: + context = self._context_builder.build(request) + # If no base extended card is provided, pass the public card to the modifier + base_card = card_to_serve if card_to_serve else self.agent_card + card_to_serve = self.extended_card_modifier(base_card, context) + + if card_to_serve: return JSONResponse( - self.extended_agent_card.model_dump( + card_to_serve.model_dump( exclude_none=True, by_alias=True, ) ) - # If supports_authenticated_extended_card is true, but no specific - # extended_agent_card was provided during server initialization, - # return a 404 + # If supports_authenticated_extended_card is true, but no + # extended_agent_card was provided, and no modifier produced a card, + # return a 404. return JSONResponse( { 'error': 'Authenticated extended agent card is supported but not configured on the server.' diff --git a/src/a2a/server/apps/jsonrpc/starlette_app.py b/src/a2a/server/apps/jsonrpc/starlette_app.py index 51974fc1..0f7de3df 100644 --- a/src/a2a/server/apps/jsonrpc/starlette_app.py +++ b/src/a2a/server/apps/jsonrpc/starlette_app.py @@ -6,11 +6,8 @@ from starlette.routing import Route from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, JSONRPCApplication, ) -from a2a.server.request_handlers.jsonrpc_handler import RequestHandler -from a2a.types import AgentCard from a2a.utils.constants import ( AGENT_CARD_WELL_KNOWN_PATH, DEFAULT_RPC_URL, @@ -30,32 +27,6 @@ class A2AStarletteApplication(JSONRPCApplication): (SSE). """ - def __init__( - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - ) -> None: - """Initializes the A2AStarletteApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - """ - super().__init__( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - ) - def routes( self, agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, @@ -87,14 +58,15 @@ def routes( ), ] - # add deprecated path only if the agent_card_url uses default well-known path if agent_card_url == AGENT_CARD_WELL_KNOWN_PATH: + # For backward compatibility, serve the agent card at the deprecated path as well. + # TODO: remove in a future release app_routes.append( Route( PREV_AGENT_CARD_WELL_KNOWN_PATH, - self.handle_deprecated_agent_card_path, + self._handle_get_agent_card, methods=['GET'], - name='agent_card_path_deprecated', + name='deprecated_agent_card', ) ) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 5fc15cf9..e2ec69a1 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -18,6 +18,8 @@ "'pip install a2a-sdk[grpc]'" ) from e +from collections.abc import Callable + import a2a.grpc.a2a_pb2_grpc as a2a_grpc from a2a import types @@ -87,6 +89,7 @@ def __init__( agent_card: AgentCard, request_handler: RequestHandler, context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], AgentCard] | None = None, ): """Initializes the GrpcHandler. @@ -96,10 +99,13 @@ def __init__( delegate requests to. context_builder: The CallContextBuilder object. If none the DefaultCallContextBuilder is used. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. """ self.agent_card = agent_card self.request_handler = request_handler self.context_builder = context_builder or DefaultCallContextBuilder() + self.card_modifier = card_modifier async def SendMessage( self, @@ -331,7 +337,10 @@ async def GetAgentCard( context: grpc.aio.ServicerContext, ) -> a2a_pb2.AgentCard: """Get the agent card for the agent served.""" - return proto_utils.ToProto.agent_card(self.agent_card) + card_to_serve = self.agent_card + if self.card_modifier: + card_to_serve = self.card_modifier(card_to_serve) + return proto_utils.ToProto.agent_card(card_to_serve) async def abort_context( self, error: ServerError, context: grpc.aio.ServicerContext diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index a0657859..97cff496 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -1,6 +1,6 @@ import logging -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Callable from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler @@ -62,6 +62,10 @@ def __init__( agent_card: AgentCard, request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], AgentCard + ] + | None = None, ): """Initializes the JSONRPCHandler. @@ -69,10 +73,14 @@ def __init__( agent_card: The AgentCard describing the agent's capabilities. request_handler: The underlying `RequestHandler` instance to delegate requests to. extended_agent_card: An optional, distinct Extended AgentCard to be served + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. """ self.agent_card = agent_card self.request_handler = request_handler self.extended_agent_card = extended_agent_card + self.extended_card_modifier = extended_card_modifier async def on_message_send( self, @@ -417,7 +425,10 @@ async def get_authenticated_extended_card( Returns: A `GetAuthenticatedExtendedCardResponse` object containing the config or a JSON-RPC error. """ - if self.extended_agent_card is None: + if ( + self.extended_agent_card is None + and self.extended_card_modifier is None + ): return GetAuthenticatedExtendedCardResponse( root=JSONRPCErrorResponse( id=request.id, @@ -425,8 +436,16 @@ async def get_authenticated_extended_card( ) ) + base_card = self.extended_agent_card + if base_card is None: + base_card = self.agent_card + + card_to_serve = base_card + if self.extended_card_modifier and context: + card_to_serve = self.extended_card_modifier(base_card, context) + return GetAuthenticatedExtendedCardResponse( root=GetAuthenticatedExtendedCardSuccessResponse( - id=request.id, result=self.extended_agent_card + id=request.id, result=card_to_serve ) ) diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 8bd65e02..05af6cda 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -201,6 +201,34 @@ async def test_get_agent_card( assert response.version == sample_agent_card.version +@pytest.mark.asyncio +async def test_get_agent_card_with_modifier( + mock_request_handler: AsyncMock, + sample_agent_card: types.AgentCard, + mock_grpc_context: AsyncMock, +): + """Test GetAgentCard call with a card_modifier.""" + + def modifier(card: types.AgentCard) -> types.AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Modified gRPC Agent' + return modified_card + + grpc_handler_modified = GrpcHandler( + agent_card=sample_agent_card, + request_handler=mock_request_handler, + card_modifier=modifier, + ) + + request_proto = a2a_pb2.GetAgentCardRequest() + response = await grpc_handler_modified.GetAgentCard( + request_proto, mock_grpc_context + ) + + assert response.name == 'Modified gRPC Agent' + assert response.version == sample_agent_card.version + + @pytest.mark.asyncio @pytest.mark.parametrize( 'server_error, grpc_status_code, error_message_part', @@ -267,7 +295,7 @@ async def test_get_agent_card( ), ], ) -async def test_abort_context_error_mapping( +async def test_abort_context_error_mapping( # noqa: PLR0913 grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index ef43b05f..b460b2f3 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -1212,6 +1212,7 @@ async def test_get_authenticated_extended_card_success(self) -> None: self.mock_agent_card, mock_request_handler, extended_agent_card=mock_extended_card, + extended_card_modifier=None, ) request = GetAuthenticatedExtendedCardRequest(id='ext-card-req-1') call_context = ServerCallContext(state={'foo': 'bar'}) @@ -1233,7 +1234,10 @@ async def test_get_authenticated_extended_card_not_configured(self) -> None: # Arrange mock_request_handler = AsyncMock(spec=DefaultRequestHandler) handler = JSONRPCHandler( - self.mock_agent_card, mock_request_handler, extended_agent_card=None + self.mock_agent_card, + mock_request_handler, + extended_agent_card=None, + extended_card_modifier=None, ) request = GetAuthenticatedExtendedCardRequest(id='ext-card-req-2') call_context = ServerCallContext(state={'foo': 'bar'}) @@ -1249,3 +1253,50 @@ async def test_get_authenticated_extended_card_not_configured(self) -> None: self.assertIsInstance( response.root.error, AuthenticatedExtendedCardNotConfiguredError ) + + async def test_get_authenticated_extended_card_with_modifier(self) -> None: + """Test successful retrieval of a dynamically modified extended agent card.""" + # Arrange + mock_request_handler = AsyncMock(spec=DefaultRequestHandler) + mock_base_card = AgentCard( + name='Base Card', + description='Base details', + url='http://agent.example.com/api', + version='1.0', + capabilities=AgentCapabilities(), + default_input_modes=['text/plain'], + default_output_modes=['application/json'], + skills=[], + ) + + def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Modified Card' + modified_card.description = ( + f'Modified for context: {context.state.get("foo")}' + ) + return modified_card + + handler = JSONRPCHandler( + self.mock_agent_card, + mock_request_handler, + extended_agent_card=mock_base_card, + extended_card_modifier=modifier, + ) + request = GetAuthenticatedExtendedCardRequest(id='ext-card-req-mod') + call_context = ServerCallContext(state={'foo': 'bar'}) + + # Act + response: GetAuthenticatedExtendedCardResponse = ( + await handler.get_authenticated_extended_card(request, call_context) + ) + + # Assert + self.assertIsInstance( + response.root, GetAuthenticatedExtendedCardSuccessResponse + ) + self.assertEqual(response.root.id, 'ext-card-req-mod') + modified_card = response.root.result + self.assertEqual(modified_card.name, 'Modified Card') + self.assertEqual(modified_card.description, 'Modified for context: bar') + self.assertEqual(modified_card.version, '1.0') diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index 8ec7f1d4..0c3bd468 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -3,14 +3,6 @@ from collections.abc import AsyncGenerator import pytest -import pytest_asyncio - -from _pytest.mark.structures import ParameterSet -from sqlalchemy import select -from sqlalchemy.ext.asyncio import ( - async_sessionmaker, - create_async_engine, -) # Skip entire test module if SQLAlchemy is not installed @@ -20,8 +12,17 @@ reason='Database tests require Cryptography. Install extra encryption', ) +import pytest_asyncio + +from _pytest.mark.structures import ParameterSet + # Now safe to import SQLAlchemy-dependent modules from cryptography.fernet import Fernet +from sqlalchemy import select +from sqlalchemy.ext.asyncio import ( + async_sessionmaker, + create_async_engine, +) from sqlalchemy.inspection import inspect from a2a.server.models import ( diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index ddd9a60f..f135349b 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -22,6 +22,7 @@ A2AFastAPIApplication, A2AStarletteApplication, ) +from a2a.server.context import ServerCallContext from a2a.types import ( AgentCapabilities, AgentCard, @@ -46,6 +47,7 @@ ) from a2a.utils import ( AGENT_CARD_WELL_KNOWN_PATH, + EXTENDED_AGENT_CARD_PATH, PREV_AGENT_CARD_WELL_KNOWN_PATH, ) from a2a.utils.errors import MethodNotImplementedError @@ -226,7 +228,7 @@ def test_authenticated_extended_agent_card_endpoint_supported_with_specific_exte agent_card.supports_authenticated_extended_card = ( True # Main card must support it ) - print(agent_card) + app_instance = A2AStarletteApplication( agent_card, handler, extended_agent_card=extended_agent_card_fixture ) @@ -846,6 +848,100 @@ def test_invalid_request_structure(client: TestClient): assert data['error']['code'] == InvalidRequestError().code +# === DYNAMIC CARD MODIFIER TESTS === + + +def test_dynamic_agent_card_modifier( + agent_card: AgentCard, handler: mock.AsyncMock +): + """Test that the card_modifier dynamically alters the public agent card.""" + + def modifier(card: AgentCard) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Dynamically Modified Agent' + return modified_card + + app_instance = A2AStarletteApplication( + agent_card, handler, card_modifier=modifier + ) + client = TestClient(app_instance.build()) + + response = client.get(AGENT_CARD_WELL_KNOWN_PATH) + assert response.status_code == 200 + data = response.json() + assert data['name'] == 'Dynamically Modified Agent' + assert ( + data['version'] == agent_card.version + ) # Ensure other fields are intact + + +def test_dynamic_extended_agent_card_modifier( + agent_card: AgentCard, + extended_agent_card_fixture: AgentCard, + handler: mock.AsyncMock, +): + """Test that the extended_card_modifier dynamically alters the extended agent card.""" + agent_card.supports_authenticated_extended_card = True + + def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.description = 'Dynamically Modified Extended Description' + return modified_card + + # Test with a base extended card + app_instance = A2AStarletteApplication( + agent_card, + handler, + extended_agent_card=extended_agent_card_fixture, + extended_card_modifier=modifier, + ) + client = TestClient(app_instance.build()) + + response = client.get(EXTENDED_AGENT_CARD_PATH) + assert response.status_code == 200 + data = response.json() + assert data['name'] == extended_agent_card_fixture.name + assert data['description'] == 'Dynamically Modified Extended Description' + + # Test without a base extended card (modifier should receive public card) + app_instance_no_base = A2AStarletteApplication( + agent_card, + handler, + extended_agent_card=None, + extended_card_modifier=modifier, + ) + client_no_base = TestClient(app_instance_no_base.build()) + response_no_base = client_no_base.get(EXTENDED_AGENT_CARD_PATH) + assert response_no_base.status_code == 200 + data_no_base = response_no_base.json() + assert data_no_base['name'] == agent_card.name + assert ( + data_no_base['description'] + == 'Dynamically Modified Extended Description' + ) + + +def test_fastapi_dynamic_agent_card_modifier( + agent_card: AgentCard, handler: mock.AsyncMock +): + """Test that the card_modifier dynamically alters the public agent card for FastAPI.""" + + def modifier(card: AgentCard) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Dynamically Modified Agent' + return modified_card + + app_instance = A2AFastAPIApplication( + agent_card, handler, card_modifier=modifier + ) + client = TestClient(app_instance.build()) + + response = client.get(AGENT_CARD_WELL_KNOWN_PATH) + assert response.status_code == 200 + data = response.json() + assert data['name'] == 'Dynamically Modified Agent' + + def test_method_not_implemented(client: TestClient, handler: mock.AsyncMock): """Test handling MethodNotImplementedError.""" handler.on_get_task.side_effect = MethodNotImplementedError()