Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/a2a/extensions/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from a2a.types import AgentCard, AgentExtension


HTTP_EXTENSION_HEADER = 'X-A2A-Extensions'


def get_requested_extensions(values: list[str]) -> set[str]:
"""Get the set of requested extensions from an input list.
This handles the list containing potentially comma-separated values, as
occurs when using a list in an HTTP header.
"""
return {
stripped
for v in values
for ext in v.split(',')
if (stripped := ext.strip())
}


def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None:
"""Find an AgentExtension in an AgentCard given a uri."""
for ext in card.capabilities.extensions or []:
if ext.uri == uri:
return ext

return None
18 changes: 18 additions & 0 deletions src/a2a/server/agent_execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,24 @@ def metadata(self) -> dict[str, Any]:
return {}
return self._params.metadata or {}

def add_activated_extension(self, uri: str) -> None:
"""Add an extension to the set of activated extensions for this request.

This causes the extension to be indicated back to the client in the
response.
"""
if self._call_context:
self._call_context.activated_extensions.add(uri)

@property
def requested_extensions(self) -> set[str]:
"""Extensions that the client requested to activate."""
return (
self._call_context.requested_extensions
if self._call_context
else set()
)

def _check_or_generate_task_id(self) -> None:
"""Ensures a task ID is present, generating one if necessary."""
if not self._params:
Expand Down
31 changes: 25 additions & 6 deletions src/a2a/server/apps/jsonrpc/jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

from a2a.auth.user import UnauthenticatedUser
from a2a.auth.user import User as A2AUser
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
get_requested_extensions,
)
from a2a.server.context import ServerCallContext
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
from a2a.server.request_handlers.request_handler import RequestHandler
Expand Down Expand Up @@ -99,7 +103,13 @@ def build(self, request: Request) -> ServerCallContext:
user = StarletteUserProxy(request.user)
state['auth'] = request.auth
state['headers'] = dict(request.headers)
return ServerCallContext(user=user, state=state)
return ServerCallContext(
user=user,
state=state,
requested_extensions=get_requested_extensions(
request.headers.getlist(HTTP_EXTENSION_HEADER)
),
)


class JSONRPCApplication(ABC):
Expand Down Expand Up @@ -281,7 +291,7 @@ async def _process_streaming_request(
request_obj, context
)

return self._create_response(handler_result)
return self._create_response(context, handler_result)

async def _process_non_streaming_request(
self,
Expand Down Expand Up @@ -353,10 +363,11 @@ async def _process_non_streaming_request(
id=request_id, error=error
)

return self._create_response(handler_result)
return self._create_response(context, handler_result)

def _create_response(
self,
context: ServerCallContext,
handler_result: (
AsyncGenerator[SendStreamingMessageResponse]
| JSONRPCErrorResponse
Expand All @@ -372,12 +383,16 @@ def _create_response(
payloads.

Args:
context: The ServerCallContext provided to the request handler.
handler_result: The result from a request handler method. Can be an
async generator for streaming or a Pydantic model for non-streaming.

Returns:
A Starlette JSONResponse or EventSourceResponse.
"""
headers = {}
if exts := context.activated_extensions:
headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts))
if isinstance(handler_result, AsyncGenerator):
# Result is a stream of SendStreamingMessageResponse objects
async def event_generator(
Expand All @@ -386,17 +401,21 @@ async def event_generator(
async for item in stream:
yield {'data': item.root.model_dump_json(exclude_none=True)}

return EventSourceResponse(event_generator(handler_result))
return EventSourceResponse(
event_generator(handler_result), headers=headers
)
if isinstance(handler_result, JSONRPCErrorResponse):
return JSONResponse(
handler_result.model_dump(
mode='json',
exclude_none=True,
)
),
headers=headers,
)

return JSONResponse(
handler_result.root.model_dump(mode='json', exclude_none=True)
handler_result.root.model_dump(mode='json', exclude_none=True),
headers=headers,
)

async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
Expand Down
2 changes: 2 additions & 0 deletions src/a2a/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ class ServerCallContext(BaseModel):

state: State = Field(default={})
user: User = Field(default=UnauthenticatedUser())
requested_extensions: set[str] = Field(default_factory=set)
activated_extensions: set[str] = Field(default_factory=set)
44 changes: 42 additions & 2 deletions src/a2a/server/request_handlers/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import logging

from abc import ABC, abstractmethod
from collections.abc import AsyncIterable
from collections.abc import AsyncIterable, Sequence


try:
import grpc
import grpc.aio

from grpc.aio import Metadata
except ImportError as e:
raise ImportError(
'GrpcHandler requires grpcio and grpcio-tools to be installed. '
Expand All @@ -20,6 +22,10 @@

from a2a import types
from a2a.auth.user import UnauthenticatedUser
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
get_requested_extensions,
)
from a2a.grpc import a2a_pb2
from a2a.server.context import ServerCallContext
from a2a.server.request_handlers.request_handler import RequestHandler
Expand All @@ -42,6 +48,19 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
"""Builds a ServerCallContext from a gRPC Request."""


def _get_metadata_value(
context: grpc.aio.ServicerContext, key: str
) -> list[str]:
md = context.invocation_metadata
raw_values: list[str | bytes] = []
if isinstance(md, Metadata):
raw_values = md.get_all(key)
elif isinstance(md, Sequence):
lower_key = key.lower()
raw_values = [e for (k, e) in md if k.lower() == lower_key]
return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values]


class DefaultCallContextBuilder(CallContextBuilder):
"""A default implementation of CallContextBuilder."""

Expand All @@ -51,7 +70,13 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
state = {}
with contextlib.suppress(Exception):
state['grpc_context'] = context
return ServerCallContext(user=user, state=state)
return ServerCallContext(
user=user,
state=state,
requested_extensions=get_requested_extensions(
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
),
)


class GrpcHandler(a2a_grpc.A2AServiceServicer):
Expand Down Expand Up @@ -102,6 +127,7 @@ async def SendMessage(
task_or_message = await self.request_handler.on_message_send(
a2a_request, server_context
)
self._set_extension_metadata(context, server_context)
return proto_utils.ToProto.task_or_message(task_or_message)
except ServerError as e:
await self.abort_context(e, context)
Expand Down Expand Up @@ -140,6 +166,7 @@ async def SendStreamingMessage(
a2a_request, server_context
):
yield proto_utils.ToProto.stream_response(event)
self._set_extension_metadata(context, server_context)
except ServerError as e:
await self.abort_context(e, context)
return
Expand Down Expand Up @@ -371,3 +398,16 @@ async def abort_context(
grpc.StatusCode.UNKNOWN,
f'Unknown error type: {error.error}',
)

def _set_extension_metadata(
self,
context: grpc.aio.ServicerContext,
server_context: ServerCallContext,
) -> None:
if server_context.activated_extensions:
context.set_trailing_metadata(
[
(HTTP_EXTENSION_HEADER, e)
for e in sorted(server_context.activated_extensions)
]
)
58 changes: 58 additions & 0 deletions tests/extensions/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from a2a.extensions.common import (
find_extension_by_uri,
get_requested_extensions,
)
from a2a.types import AgentCapabilities, AgentCard, AgentExtension


def test_get_requested_extensions():
assert get_requested_extensions([]) == set()
assert get_requested_extensions(['foo']) == {'foo'}
assert get_requested_extensions(['foo', 'bar']) == {'foo', 'bar'}
assert get_requested_extensions(['foo, bar']) == {'foo', 'bar'}
assert get_requested_extensions(['foo,bar']) == {'foo', 'bar'}
assert get_requested_extensions(['foo', 'bar,baz']) == {'foo', 'bar', 'baz'}
assert get_requested_extensions(['foo,, bar', 'baz']) == {
'foo',
'bar',
'baz',
}
assert get_requested_extensions([' foo , bar ', 'baz']) == {
'foo',
'bar',
'baz',
}


def test_find_extension_by_uri():
ext1 = AgentExtension(uri='foo', description='The Foo extension')
ext2 = AgentExtension(uri='bar', description='The Bar extension')
card = AgentCard(
name='Test Agent',
description='Test Agent Description',
version='1.0',
url='http://test.com',
skills=[],
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
capabilities=AgentCapabilities(extensions=[ext1, ext2]),
)

assert find_extension_by_uri(card, 'foo') == ext1
assert find_extension_by_uri(card, 'bar') == ext2
assert find_extension_by_uri(card, 'baz') is None


def test_find_extension_by_uri_no_extensions():
card = AgentCard(
name='Test Agent',
description='Test Agent Description',
version='1.0',
url='http://test.com',
skills=[],
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
capabilities=AgentCapabilities(extensions=None),
)

assert find_extension_by_uri(card, 'foo') is None
14 changes: 14 additions & 0 deletions tests/server/agent_execution/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from a2a.server.agent_execution import RequestContext
from a2a.server.context import ServerCallContext
from a2a.types import (
Message,
MessageSendParams,
Expand Down Expand Up @@ -263,3 +264,16 @@ def test_init_with_context_id_and_existing_context_id_match(

assert context.context_id == mock_task.context_id
assert context.current_task == mock_task

def test_extension_handling(self):
"""Test extension handling in RequestContext."""
call_context = ServerCallContext(requested_extensions={'foo', 'bar'})
context = RequestContext(call_context=call_context)

assert context.requested_extensions == {'foo', 'bar'}

context.add_activated_extension('foo')
assert call_context.activated_extensions == {'foo'}

context.add_activated_extension('baz')
assert call_context.activated_extensions == {'foo', 'baz'}
Loading
Loading