diff --git a/python/packages/a2a/AGENTS.md b/python/packages/a2a/AGENTS.md index af6e4a492b..1474c59ef5 100644 --- a/python/packages/a2a/AGENTS.md +++ b/python/packages/a2a/AGENTS.md @@ -4,20 +4,48 @@ Agent-to-Agent (A2A) protocol support for inter-agent communication. ## Main Classes -- **`A2AAgent`** - Agent wrapper that exposes an agent via the A2A protocol +- **`A2AAgent`** - Client to connect to remote A2A-compliant agents. +- **`A2AExecutor`** - Bridge to expose Agent Framework agents via the A2A protocol. ## Usage +### A2AAgent (Client) + ```python from agent_framework.a2a import A2AAgent -a2a_agent = A2AAgent(agent=my_agent) +# Connect to a remote A2A agent +a2a_agent = A2AAgent(url="http://remote-agent/a2a") +response = await a2a_agent.run("Hello!") +``` + +### A2AExecutor (Server/Bridge) + +```python +from agent_framework.a2a import A2AExecutor +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore + +# Create an A2A executor for your agent +executor = A2AExecutor(agent=my_agent) + +# Set up the request handler and server application +request_handler = DefaultRequestHandler( + agent_executor=executor, + task_store=InMemoryTaskStore(), +) + +app = A2AStarletteApplication( + agent_card=my_agent_card, + http_handler=request_handler, +).build() ``` ## Import Path ```python -from agent_framework.a2a import A2AAgent +from agent_framework.a2a import A2AAgent, A2AExecutor # or directly: -from agent_framework_a2a import A2AAgent +from agent_framework_a2a import A2AAgent, A2AExecutor ``` diff --git a/python/packages/a2a/README.md b/python/packages/a2a/README.md index 5ae15e3647..4bdfa9221e 100644 --- a/python/packages/a2a/README.md +++ b/python/packages/a2a/README.md @@ -10,11 +10,49 @@ pip install agent-framework-a2a --pre The A2A agent integration enables communication with remote A2A-compliant agents using the standardized A2A protocol. This allows your Agent Framework applications to connect to agents running on different platforms, languages, or services. +### A2AAgent (Client) + +The `A2AAgent` class is a client that wraps an A2A Client to connect the Agent Framework with external A2A-compliant agents. + +```python +from agent_framework.a2a import A2AAgent + +# Connect to a remote A2A agent +a2a_agent = A2AAgent(url="http://remote-agent/a2a") +response = await a2a_agent.run("Hello!") +``` + +### A2AExecutor (Hosting) + +The `A2AExecutor` class bridges local AI agents built with the `agent_framework` library to the A2A protocol, allowing them to be hosted and accessed by other A2A-compliant clients. + +```python +from agent_framework.a2a import A2AExecutor +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore + +# Create an A2A executor for your agent +executor = A2AExecutor(agent=my_agent) + +# Set up the request handler and server application +request_handler = DefaultRequestHandler( + agent_executor=executor, + task_store=InMemoryTaskStore(), +) + +app = A2AStarletteApplication( + agent_card=my_agent_card, + http_handler=request_handler, +).build() +``` + ### Basic Usage Example See the [A2A agent examples](../../samples/04-hosting/a2a/) which demonstrate: - Connecting to remote A2A agents +- Hosting local agents via A2A protocol - Sending messages and receiving responses - Handling different content types (text, files, data) - Streaming responses and real-time interaction diff --git a/python/packages/a2a/agent_framework_a2a/__init__.py b/python/packages/a2a/agent_framework_a2a/__init__.py index 4b4d54ecc3..c5338965c2 100644 --- a/python/packages/a2a/agent_framework_a2a/__init__.py +++ b/python/packages/a2a/agent_framework_a2a/__init__.py @@ -2,6 +2,7 @@ import importlib.metadata +from ._a2a_executor import A2AExecutor from ._agent import A2AAgent, A2AContinuationToken try: @@ -12,5 +13,6 @@ __all__ = [ "A2AAgent", "A2AContinuationToken", + "A2AExecutor", "__version__", ] diff --git a/python/packages/a2a/agent_framework_a2a/_a2a_executor.py b/python/packages/a2a/agent_framework_a2a/_a2a_executor.py new file mode 100644 index 0000000000..0cf6d835a6 --- /dev/null +++ b/python/packages/a2a/agent_framework_a2a/_a2a_executor.py @@ -0,0 +1,275 @@ +# Copyright (c) Microsoft. All rights reserved. + +import logging +from asyncio import CancelledError +from collections.abc import Mapping +from functools import partial +from typing import Any + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.server.tasks import TaskUpdater +from a2a.types import FilePart, FileWithBytes, FileWithUri, Part, TaskState, TextPart +from a2a.utils import new_task +from agent_framework import ( + AgentResponseUpdate, + AgentSession, + Message, + SupportsAgentRun, +) +from typing_extensions import override + +from agent_framework_a2a._utils import get_uri_data + +logger = logging.getLogger("agent_framework.a2a") + + +class A2AExecutor(AgentExecutor): + """Execute AI agents using the A2A (Agent-to-Agent) protocol. + + The A2AExecutor bridges AI agents built with the agent_framework library and the A2A protocol, + enabling structured agent execution with event-driven communication. It handles execution + contexts, delegates history management to the agent's session, and converts agent + responses into A2A protocol events. + + The executor supports executing an Agent or WorkflowAgent. It provides comprehensive + error handling with task status updates and supports various content types including text, + binary data, and URI-based content. + + Example: + .. code-block:: python + + from a2a.server.apps import A2AStarletteApplication + from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server.tasks import InMemoryTaskStore + from a2a.types import AgentCapabilities, AgentCard + from agent_framework.a2a import A2AExecutor + from agent_framework.openai import OpenAIResponsesClient + + public_agent_card = AgentCard( + name="Food Agent", + description="A simple agent that provides food-related information.", + url="http://localhost:9999/", + version="1.0.0", + defaultInputModes=["text"], + defaultOutputModes=["text"], + capabilities=AgentCapabilities(streaming=True), + skills=[], + ) + + # Create an agent + agent = OpenAIResponsesClient().as_agent( + name="Food Agent", + instructions="A simple agent that provides food-related information.", + ) + + # Set up the A2A server with the A2AExecutor enabled for streaming + # and passing custom keyword arguments to the agent's run method. + request_handler = DefaultRequestHandler( + agent_executor=A2AExecutor(agent, stream=True, run_kwargs={"client_kwargs": {"max_tokens": 500}}), + task_store=InMemoryTaskStore(), + ) + + server = A2AStarletteApplication( + agent_card=public_agent_card, + http_handler=request_handler, + ).build() + + Args: + agent: The AI agent to execute. + stream: Whether to stream the agent response. Defaults to False. + run_kwargs: Additional keyword arguments to pass to the agent's run method. + """ + + def __init__(self, agent: SupportsAgentRun, stream: bool = False, run_kwargs: Mapping[str, Any] | None = None): + """Initialize the A2AExecutor with the specified agent. + + Args: + agent: The AI agent or workflow to execute. + stream: Whether to stream the agent response. Defaults to False. + run_kwargs: Additional keyword arguments to pass to the agent's run method. + Cannot contain 'session' or 'stream' as these are managed by the executor. + + Raises: + ValueError: If run_kwargs contains 'session' or 'stream'. + """ + super().__init__() + self._agent: SupportsAgentRun = agent + self._stream: bool = stream + if run_kwargs: + if "session" in run_kwargs: + raise ValueError("run_kwargs cannot contain 'session' as it is managed by the executor.") + if "stream" in run_kwargs: + raise ValueError("run_kwargs cannot contain 'stream' as it is managed by the executor.") + self._run_kwargs: Mapping[str, Any] = run_kwargs or {} + + @override + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + """Cancel agent execution for the given request context. + + Uses a TaskUpdater to send a cancellation event through the provided event queue. + + Args: + context: The request context identifying the task to cancel. + event_queue: The event queue to publish the cancellation event to. + + Raises: + ValueError: If context_id is not provided in the RequestContext. + """ + if context.context_id is None: + raise ValueError("Context ID must be provided in the RequestContext") + + updater = TaskUpdater( + event_queue=event_queue, + task_id=context.task_id or "", + context_id=context.context_id, + ) + + await updater.cancel() + + @override + async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: + """Execute the agent with the given context and event queue. + + Orchestrates the agent execution process: sets up the agent session, + executes the agent, processes response messages, and handles errors with appropriate task status updates. + """ + if context.context_id is None: + raise ValueError("Context ID must be provided in the RequestContext") + if context.message is None: + raise ValueError("Message must be provided in the RequestContext") + + query = context.get_user_input() + task = context.current_task + + if not task: + task = new_task(context.message) + await event_queue.enqueue_event(task) + + updater = TaskUpdater(event_queue, task.id, context.context_id) + await updater.submit() + + try: + await updater.start_work() + + session = self._agent.create_session(session_id=task.context_id) + + if self._stream: + await self._run_stream(query, session, updater) + else: + await self._run(query, session, updater) + + # Mark as complete + await updater.complete() + except CancelledError: + await updater.update_status(state=TaskState.canceled, final=True) + except Exception as e: + logger.exception("A2AExecutor encountered an error during execution.", exc_info=e) + await updater.update_status( + state=TaskState.failed, + final=True, + message=updater.new_agent_message([Part(root=TextPart(text=str(e)))]), + ) + + async def _run_stream(self, query: Any, session: AgentSession, updater: TaskUpdater) -> None: + """Run the agent in streaming mode and publish updates to the task updater.""" + response_stream = self._agent.run(query, session=session, stream=True, **self._run_kwargs) + streamed_artifact_ids: set[str] = set() + await ( + response_stream.with_transform_hook( + partial(self.handle_events, updater=updater, streamed_artifact_ids=streamed_artifact_ids) + ) + ).get_final_response() + + async def _run(self, query: Any, session: AgentSession, updater: TaskUpdater) -> None: + """Run the agent in non-streaming mode and publish messages to the task updater.""" + response = await self._agent.run(query, session=session, stream=False, **self._run_kwargs) + response_messages = response.messages + + if not isinstance(response_messages, list): + response_messages = [response_messages] + + for message in response_messages: + await self.handle_events(message, updater) + + async def handle_events( + self, item: Message | AgentResponseUpdate, updater: TaskUpdater, streamed_artifact_ids: set[str] | None = None + ) -> None: + """Convert agent response items (Messages or Updates) to A2A protocol events. + + Processes Message or AgentResponseUpdate objects and converts them into A2A protocol format. + Handles text, data, and URI content. USER role messages are skipped. + + Users can override this method in a subclass to implement custom transformations + from their agent's output format to A2A protocol events. + + Args: + item: The agent response item (Message or AgentResponseUpdate) to process. + updater: The task updater to publish events to. + streamed_artifact_ids: A set of artifact IDs that have already been streamed. + Used to prevent duplicate updates for the same artifact. + + Example: + .. code-block:: python + + class CustomA2AExecutor(A2AExecutor): + async def handle_events( + self, + item: Message | AgentResponseUpdate, + updater: TaskUpdater, + streamed_artifact_ids: set[str] | None = None, + ) -> None: + # Custom logic to transform item contents + if item.role == "assistant" and item.contents: + parts = [Part(root=TextPart(text=f"Custom: {item.contents[0].text}"))] + await updater.update_status( + state=TaskState.working, + message=updater.new_agent_message(parts=parts), + ) + else: + await super().handle_events(item, updater) + """ + role = getattr(item, "role", None) + if role == "user": + # This is a user message, we can ignore it in the context of task updates + return + + parts: list[Part] = [] + metadata = getattr(item, "additional_properties", None) + + # AgentResponseUpdate uses 'contents', Message uses 'contents' + contents = getattr(item, "contents", []) + + for content in contents: + if content.type == "text" and content.text: + parts.append(Part(root=TextPart(text=content.text))) + elif content.type == "data" and content.uri: + base64_str = get_uri_data(content.uri) + parts.append(Part(root=FilePart(file=FileWithBytes(bytes=base64_str, mime_type=content.media_type)))) + elif content.type == "uri" and content.uri: + parts.append(Part(root=FilePart(file=FileWithUri(uri=content.uri, mime_type=content.media_type)))) + else: + # Silently skip unsupported content types + logger.warning("A2AExecutor does not yet support content type: %s. Omitted.", content.type) + + if parts: + if isinstance(item, AgentResponseUpdate): + # For streaming updates, we send TaskArtifactUpdateEvent via add_artifact + await updater.add_artifact( + parts=parts, + artifact_id=item.message_id, + metadata=metadata, + append=( + True + if streamed_artifact_ids is not None and item.message_id in (streamed_artifact_ids or set()) + else None + ), + ) + if item.message_id and streamed_artifact_ids is not None: + streamed_artifact_ids.add(item.message_id) + else: + # For final messages, we send TaskStatusUpdateEvent with 'working' state + await updater.update_status( + state=TaskState.working, + message=updater.new_agent_message(parts=parts, metadata=metadata), + ) diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index a07be3cf2f..4025687cfd 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -4,7 +4,6 @@ import base64 import json -import re import uuid from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence from typing import Any, Final, Literal, TypeAlias, overload @@ -49,7 +48,7 @@ __all__ = ["A2AAgent", "A2AContinuationToken"] -URI_PATTERN = re.compile(r"^data:(?P[^;]+);base64,(?P[A-Za-z0-9+/=]+)$") +from agent_framework_a2a._utils import get_uri_data class A2AContinuationToken(ContinuationToken): @@ -78,14 +77,6 @@ class A2AContinuationToken(ContinuationToken): A2AStreamItem: TypeAlias = A2AMessage | A2AClientEvent -def _get_uri_data(uri: str) -> str: - match = URI_PATTERN.match(uri) - if not match: - raise ValueError(f"Invalid data URI format: {uri}") - - return match.group("base64_data") - - class A2AAgent(AgentTelemetryLayer, BaseAgent): """Agent2Agent (A2A) protocol implementation. @@ -642,7 +633,7 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: A2APart( root=FilePart( file=FileWithBytes( - bytes=_get_uri_data(content.uri), + bytes=get_uri_data(content.uri), mime_type=content.media_type, ), metadata=content.additional_properties, diff --git a/python/packages/a2a/agent_framework_a2a/_utils.py b/python/packages/a2a/agent_framework_a2a/_utils.py new file mode 100644 index 0000000000..2b0a8e1600 --- /dev/null +++ b/python/packages/a2a/agent_framework_a2a/_utils.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft. All rights reserved. + +import re + +URI_PATTERN = re.compile(r"^data:(?P[^;]+);base64,(?P[A-Za-z0-9+/=]+)$") + + +def get_uri_data(uri: str) -> str: + """Extracts the base64-encoded data from a data URI. + + Args: + uri: The data URI to parse. + + Returns: + The base64-encoded data part of the URI. + + Raises: + ValueError: If the URI format is invalid. + """ + match = URI_PATTERN.match(uri) + if not match: + raise ValueError(f"Invalid data URI format: {uri}") + + return match.group("base64_data") diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 484d71e22c..aaefaf7fc8 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -35,7 +35,7 @@ from pytest import fixture, mark, raises from agent_framework_a2a import A2AContinuationToken -from agent_framework_a2a._agent import _get_uri_data # type: ignore +from agent_framework_a2a._utils import get_uri_data class MockA2AClient: @@ -351,18 +351,18 @@ def test_parse_message_from_artifact(a2a_agent: A2AAgent) -> None: def test_get_uri_data_valid_uri() -> None: - """Test _get_uri_data with valid data URI.""" + """Test get_uri_data with valid data URI.""" uri = "data:application/json;base64,eyJ0ZXN0IjoidmFsdWUifQ==" - result = _get_uri_data(uri) + result = get_uri_data(uri) assert result == "eyJ0ZXN0IjoidmFsdWUifQ==" def test_get_uri_data_invalid_uri() -> None: - """Test _get_uri_data with invalid URI format.""" + """Test get_uri_data with invalid URI format.""" with raises(ValueError, match="Invalid data URI format"): - _get_uri_data("not-a-valid-data-uri") + get_uri_data("not-a-valid-data-uri") def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None: diff --git a/python/packages/a2a/tests/test_a2a_executor.py b/python/packages/a2a/tests/test_a2a_executor.py new file mode 100644 index 0000000000..bd3ead046e --- /dev/null +++ b/python/packages/a2a/tests/test_a2a_executor.py @@ -0,0 +1,910 @@ +# Copyright (c) Microsoft. All rights reserved. +from asyncio import CancelledError +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +from a2a.types import Task, TaskState, TextPart +from agent_framework import ( + AgentResponseUpdate, + Content, + Message, + SupportsAgentRun, +) +from agent_framework._types import AgentResponse +from agent_framework.a2a import A2AExecutor +from pytest import fixture, raises + + +@fixture +def mock_agent() -> MagicMock: + """Fixture that provides a mock SupportsAgentRun.""" + agent = MagicMock(spec=SupportsAgentRun) + agent.run = AsyncMock() + return agent + + +@fixture +def mock_request_context() -> MagicMock: + """Fixture that provides a mock RequestContext.""" + request_context = MagicMock() + request_context.context_id = str(uuid4()) + request_context.get_user_input = MagicMock(return_value="Test query") + request_context.current_task = None + request_context.message = None + return request_context + + +@fixture +def mock_event_queue() -> MagicMock: + """Fixture that provides a mock EventQueue.""" + queue = AsyncMock() + queue.enqueue_event = AsyncMock() + return queue + + +@fixture +def mock_task() -> Task: + """Fixture that provides a mock Task.""" + task = MagicMock(spec=Task) + task.id = str(uuid4()) + task.context_id = str(uuid4()) + task.state = TaskState.completed + return task + + +@fixture +def mock_task_updater() -> MagicMock: + """Fixture that provides a mock TaskUpdater.""" + updater = MagicMock() + updater.submit = AsyncMock() + updater.start_work = AsyncMock() + updater.complete = AsyncMock() + updater.update_status = AsyncMock() + updater.new_agent_message = MagicMock() + return updater + + +@fixture +def executor(mock_agent: MagicMock) -> A2AExecutor: + """Fixture that provides an A2AExecutor.""" + return A2AExecutor(agent=mock_agent) + + +class TestA2AExecutorInitialization: + """Tests for A2AExecutor initialization.""" + + def test_initialization_with_agent_only(self, mock_agent: MagicMock) -> None: + """Arrange: Create mock agent + Act: Initialize A2AExecutor with only agent + Assert: Executor is created with default values + """ + # Act + executor = A2AExecutor(agent=mock_agent) + + # Assert + assert executor._agent is mock_agent + assert executor._stream is False + assert executor._run_kwargs == {} + + def test_initialization_with_stream_and_kwargs(self, mock_agent: MagicMock) -> None: + """Arrange: Create mock agent + Act: Initialize A2AExecutor with stream and run_kwargs + Assert: Executor is created with specified values + """ + # Arrange + run_kwargs = {"temperature": 0.5} + + # Act + executor = A2AExecutor(agent=mock_agent, stream=True, run_kwargs=run_kwargs) + + # Assert + assert executor._agent is mock_agent + assert executor._stream is True + assert executor._run_kwargs == run_kwargs + + def test_initialization_with_invalid_run_kwargs(self, mock_agent: MagicMock) -> None: + """Arrange: Create mock agent + Act: Initialize A2AExecutor with reserved keys in run_kwargs + Assert: ValueError is raised + """ + # Act & Assert + with raises(ValueError, match="run_kwargs cannot contain 'session'"): + A2AExecutor(agent=mock_agent, run_kwargs={"session": "something"}) + + with raises(ValueError, match="run_kwargs cannot contain 'stream'"): + A2AExecutor(agent=mock_agent, run_kwargs={"stream": True}) + + +class TestA2AExecutorCancel: + """Tests for the cancel method.""" + + async def test_cancel_method_completes( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + ) -> None: + """Arrange: Create executor with dependencies + Act: Call cancel method + Assert: Method completes without raising error + """ + # Arrange + mock_request_context.task_id = "task-123" + + # Act & Assert (should not raise) + await executor.cancel(mock_request_context, mock_event_queue) # type: ignore + + async def test_cancel_handles_different_contexts( + self, + executor: A2AExecutor, + mock_event_queue: MagicMock, + ) -> None: + """Arrange: Create executor with multiple request contexts + Act: Call cancel with different contexts + Assert: Each cancel completes successfully + """ + # Arrange + context1 = MagicMock() + context1.context_id = "ctx-1" + context1.task_id = "task-1" + context2 = MagicMock() + context2.context_id = "ctx-2" + context2.task_id = "task-2" + + # Act & Assert + await executor.cancel(context1, mock_event_queue) # type: ignore + await executor.cancel(context2, mock_event_queue) # type: ignore + + async def test_cancel_raises_error_when_context_id_missing( + self, + executor: A2AExecutor, + mock_event_queue: MagicMock, + ) -> None: + """Arrange: Create context without context_id + Act: Call cancel method + Assert: ValueError is raised + """ + # Arrange + mock_context = MagicMock() + mock_context.context_id = None + + # Act & Assert + with raises(ValueError) as excinfo: + await executor.cancel(mock_context, mock_event_queue) # type: ignore + + # Assert + assert "Context ID" in str(excinfo.value) + + +class TestA2AExecutorExecute: + """Tests for the execute method.""" + + async def test_execute_with_existing_task_succeeds( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + mock_task: Task, + ) -> None: + """Arrange: Create executor with mocked dependencies and existing task + Act: Call execute method + Assert: Execution completes successfully + """ + # Arrange + mock_request_context.get_user_input = MagicMock(return_value="Hello") + mock_request_context.current_task = mock_task + mock_request_context.context_id = "ctx-123" + mock_request_context.message = MagicMock() + + response_message = Message(role="assistant", contents=[Content.from_text(text="Hello back")]) + response = MagicMock(spec=AgentResponse) + response.messages = [response_message] + executor._agent.run = AsyncMock(return_value=response) + executor._agent.create_session = MagicMock() + + with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: + mock_updater = MagicMock() + mock_updater.submit = AsyncMock() + mock_updater.start_work = AsyncMock() + mock_updater.complete = AsyncMock() + mock_updater.update_status = AsyncMock() + mock_updater.new_agent_message = MagicMock(return_value="message_obj") + mock_updater_class.return_value = mock_updater + + # Act + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + mock_updater.submit.assert_called_once() + mock_updater.start_work.assert_called_once() + mock_updater.complete.assert_called_once() + executor._agent.create_session.assert_called_once() + executor._agent.run.assert_called_once() + + async def test_execute_creates_task_when_not_exists( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + ) -> None: + """Arrange: Create executor with request context without task + Act: Call execute method + Assert: New task is created and enqueued + """ + # Arrange + mock_message = MagicMock() + mock_request_context.get_user_input = MagicMock(return_value="Hello") + mock_request_context.current_task = None + mock_request_context.message = mock_message + mock_request_context.context_id = "ctx-123" + + response_message = Message(role="assistant", contents=[Content.from_text(text="Response")]) + response = MagicMock(spec=AgentResponse) + response.messages = [response_message] + executor._agent.run = AsyncMock(return_value=response) + executor._agent.create_session = MagicMock() + + with patch("agent_framework_a2a._a2a_executor.new_task") as mock_new_task: + mock_task = MagicMock(spec=Task) + mock_task.id = "task-new" + mock_task.context_id = "ctx-123" + mock_new_task.return_value = mock_task + + with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: + mock_updater = MagicMock() + mock_updater.submit = AsyncMock() + mock_updater.start_work = AsyncMock() + mock_updater.complete = AsyncMock() + mock_updater.update_status = AsyncMock() + mock_updater.new_agent_message = MagicMock(return_value="message_obj") + mock_updater_class.return_value = mock_updater + + # Act + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + mock_new_task.assert_called_once() + mock_event_queue.enqueue_event.assert_called_once() + + async def test_execute_raises_error_when_context_id_missing( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + ) -> None: + """Arrange: Create context without context_id + Act: Call execute method + Assert: ValueError is raised + """ + # Arrange + mock_request_context.context_id = None + mock_request_context.message = MagicMock() + + # Act & Assert + with raises(ValueError) as excinfo: + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + assert "Context ID" in str(excinfo.value) + + async def test_execute_raises_error_when_message_missing( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + ) -> None: + """Arrange: Create context without message + Act: Call execute method + Assert: ValueError is raised + """ + # Arrange + mock_request_context.context_id = "ctx-123" + mock_request_context.message = None + + # Act & Assert + with raises(ValueError) as excinfo: + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + assert "Message" in str(excinfo.value) + + async def test_execute_handles_cancelled_error( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + mock_task: Task, + ) -> None: + """Arrange: Create executor that raises CancelledError + Act: Call execute method + Assert: Error is caught and task is marked as canceled + """ + # Arrange + mock_request_context.get_user_input = MagicMock(return_value="Hello") + mock_request_context.current_task = mock_task + mock_request_context.context_id = "ctx-123" + mock_request_context.message = MagicMock() + + executor._agent.run = AsyncMock(side_effect=CancelledError()) + executor._agent.create_session = MagicMock() + + with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: + mock_updater = MagicMock() + mock_updater.submit = AsyncMock() + mock_updater.start_work = AsyncMock() + mock_updater.update_status = AsyncMock() + mock_updater_class.return_value = mock_updater + + # Act + await executor.execute(mock_request_context, mock_event_queue) # type: ignore + + # Assert + mock_updater.update_status.assert_called() + call_args_list = mock_updater.update_status.call_args_list + assert any( + call[1].get("state") == TaskState.canceled and call[1].get("final") is True for call in call_args_list + ) + + async def test_execute_handles_generic_exception( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + mock_task: Task, + ) -> None: + """Arrange: Create executor that raises generic exception + Act: Call execute method + Assert: Error is caught and task is marked as failed + """ + # Arrange + mock_request_context.get_user_input = MagicMock(return_value="Hello") + mock_request_context.current_task = mock_task + mock_request_context.context_id = "ctx-123" + mock_request_context.message = MagicMock() + + error_message = "Test error" + executor._agent.run = AsyncMock(side_effect=ValueError(error_message)) + executor._agent.create_session = MagicMock() + + with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: + mock_updater = MagicMock() + mock_updater.submit = AsyncMock() + mock_updater.start_work = AsyncMock() + mock_updater.update_status = AsyncMock() + mock_updater.new_agent_message = MagicMock(return_value="error_message_obj") + mock_updater_class.return_value = mock_updater + + # Act + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + mock_updater.new_agent_message.assert_called_once() + args, _ = mock_updater.new_agent_message.call_args + parts = args[0] + assert len(parts) == 1 + assert isinstance(parts[0].root, TextPart) + assert parts[0].root.text == error_message + + call_args_list = mock_updater.update_status.call_args_list + assert any( + call[1].get("state") == TaskState.failed + and call[1].get("final") is True + and call[1].get("message") == "error_message_obj" + for call in call_args_list + ) + + async def test_execute_processes_multiple_response_messages( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + mock_task: Task, + ) -> None: + """Arrange: Create executor that returns multiple response messages + Act: Call execute method + Assert: All messages are processed through handle_events + """ + # Arrange + mock_request_context.get_user_input = MagicMock(return_value="Hello") + mock_request_context.current_task = mock_task + mock_request_context.context_id = "ctx-123" + mock_request_context.message = MagicMock() + + response_message1 = Message(role="assistant", contents=[Content.from_text(text="First")]) + response_message2 = Message(role="assistant", contents=[Content.from_text(text="Second")]) + response = MagicMock(spec=AgentResponse) + response.messages = [response_message1, response_message2] + executor._agent.run = AsyncMock(return_value=response) + executor._agent.create_session = MagicMock() + + # Mock handle_events + executor.handle_events = AsyncMock() + + with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: + mock_updater = MagicMock() + mock_updater.submit = AsyncMock() + mock_updater.start_work = AsyncMock() + mock_updater.complete = AsyncMock() + mock_updater.update_status = AsyncMock() + mock_updater_class.return_value = mock_updater + + # Act + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + assert executor.handle_events.call_count == 2 + + async def test_execute_passes_query_to_run( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + mock_task: Task, + ) -> None: + """Arrange: Create executor with request + Act: Call execute method + Assert: Query text is passed to run method with default stream and kwargs + """ + # Arrange + query_text = "Hello agent" + mock_request_context.get_user_input = MagicMock(return_value=query_text) + mock_request_context.current_task = mock_task + mock_request_context.context_id = "ctx-123" + mock_request_context.message = MagicMock() + + response_message = Message(role="assistant", contents=[Content.from_text(text="Response")]) + response = MagicMock(spec=AgentResponse) + response.messages = [response_message] + executor._agent.run = AsyncMock(return_value=response) + executor._agent.create_session = MagicMock() + + with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: + mock_updater = MagicMock() + mock_updater.submit = AsyncMock() + mock_updater.start_work = AsyncMock() + mock_updater.complete = AsyncMock() + mock_updater.update_status = AsyncMock() + mock_updater.new_agent_message = MagicMock(return_value="message_obj") + mock_updater_class.return_value = mock_updater + + # Act + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + executor._agent.run.assert_called_once_with( + query_text, session=executor._agent.create_session(), stream=False + ) + + async def test_execute_with_stream_enabled( + self, + mock_agent: MagicMock, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + mock_task: Task, + ) -> None: + """Arrange: Create executor with stream=True + Act: Call execute method + Assert: _run_stream is called and passes stream=True to run + """ + # Arrange + executor = A2AExecutor(agent=mock_agent, stream=True) + query_text = "Hello agent" + mock_request_context.get_user_input = MagicMock(return_value=query_text) + mock_request_context.current_task = mock_task + mock_request_context.context_id = "ctx-123" + mock_request_context.message = MagicMock() + + mock_response_stream = MagicMock() + mock_response_stream.with_transform_hook = MagicMock(return_value=mock_response_stream) + mock_response_stream.get_final_response = AsyncMock() + mock_agent.run = MagicMock(return_value=mock_response_stream) + mock_agent.create_session = MagicMock() + + with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: + mock_updater = MagicMock() + mock_updater.submit = AsyncMock() + mock_updater.start_work = AsyncMock() + mock_updater.complete = AsyncMock() + mock_updater.update_status = AsyncMock() + mock_updater_class.return_value = mock_updater + + # Act + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + mock_agent.run.assert_called_once_with(query_text, session=mock_agent.create_session(), stream=True) + mock_response_stream.with_transform_hook.assert_called_once() + mock_response_stream.get_final_response.assert_called_once() + + async def test_execute_with_run_kwargs( + self, + mock_agent: MagicMock, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + mock_task: Task, + ) -> None: + """Arrange: Create executor with run_kwargs + Act: Call execute method + Assert: run_kwargs are passed to run method + """ + # Arrange + run_kwargs = {"temperature": 0.5, "max_tokens": 100} + executor = A2AExecutor(agent=mock_agent, run_kwargs=run_kwargs) + query_text = "Hello agent" + mock_request_context.get_user_input = MagicMock(return_value=query_text) + mock_request_context.current_task = mock_task + mock_request_context.context_id = "ctx-123" + mock_request_context.message = MagicMock() + + response_message = Message(role="assistant", contents=[Content.from_text(text="Response")]) + response = MagicMock(spec=AgentResponse) + response.messages = [response_message] + mock_agent.run = AsyncMock(return_value=response) + mock_agent.create_session = MagicMock() + + with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: + mock_updater = MagicMock() + mock_updater.submit = AsyncMock() + mock_updater.start_work = AsyncMock() + mock_updater.complete = AsyncMock() + mock_updater.update_status = AsyncMock() + mock_updater_class.return_value = mock_updater + + # Act + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + mock_agent.run.assert_called_once_with( + query_text, session=mock_agent.create_session(), stream=False, **run_kwargs + ) + + +class TestA2AExecutorHandleEvents: + """Tests for A2AExecutor.handle_events method.""" + + async def test_run_method_with_single_message(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test the private _run method with a single message (not a list).""" + # Arrange + query = "test query" + session = MagicMock() + response_message = Message(role="assistant", contents=[Content.from_text(text="Response")]) + response = MagicMock(spec=AgentResponse) + response.messages = response_message # Not a list + executor._agent.run = AsyncMock(return_value=response) + executor.handle_events = AsyncMock() + + # Act + await executor._run(query, session, mock_updater) + + # Assert + executor.handle_events.assert_called_once_with(response_message, mock_updater) + + @fixture + def mock_updater(self) -> MagicMock: + """Create a mock execution context.""" + updater = MagicMock() + updater.update_status = AsyncMock() + updater.new_agent_message = MagicMock(return_value="mock_message") + return updater + + async def test_ignore_user_messages(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test that messages from USER role are ignored.""" + # Arrange + message = Message( + contents=[Content.from_text(text="User input")], + role="user", + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.update_status.assert_not_called() + + async def test_ignore_messages_with_no_contents(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test that messages with no contents are ignored.""" + # Arrange + message = Message( + contents=[], + role="assistant", + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.update_status.assert_not_called() + + async def test_handle_text_content(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test handling messages with text content.""" + # Arrange + text = "Hello, this is a test message" + message = Message( + contents=[Content.from_text(text=text)], + role="assistant", + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.update_status.assert_called_once() + call_args = mock_updater.update_status.call_args + assert call_args.kwargs["state"] == TaskState.working + assert mock_updater.new_agent_message.called + + async def test_handle_multiple_text_contents(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test handling messages with multiple text contents.""" + # Arrange + message = Message( + contents=[ + Content.from_text(text="First message"), + Content.from_text(text="Second message"), + ], + role="assistant", + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.update_status.assert_called_once() + assert mock_updater.new_agent_message.called + + async def test_handle_data_content(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test handling messages with data content.""" + # Arrange + data = b"test file data" + message = Message( + contents=[Content.from_data(data=data, media_type="application/octet-stream")], + role="assistant", + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.update_status.assert_called_once() + call_args = mock_updater.update_status.call_args + assert call_args.kwargs["state"] == TaskState.working + + async def test_handle_uri_content(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test handling messages with URI content.""" + # Arrange + uri = "https://example.com/file.pdf" + message = Message( + contents=[Content.from_uri(uri=uri, media_type="application/pdf")], + role="assistant", + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.update_status.assert_called_once() + call_args = mock_updater.update_status.call_args + assert call_args.kwargs["state"] == TaskState.working + + async def test_handle_mixed_content_types(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test handling messages with mixed content types.""" + # Arrange + data = b"file data" + + message = Message( + contents=[ + Content.from_text(text="Processing file..."), + Content.from_data(data=data, media_type="application/octet-stream"), + Content.from_uri(uri="https://example.com/reference.pdf", media_type="application/pdf"), + ], + role="assistant", + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.update_status.assert_called_once() + call_args = mock_updater.update_status.call_args + assert call_args.kwargs["state"] == TaskState.working + + async def test_handle_with_additional_properties(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test handling messages with additional properties metadata.""" + # Arrange + additional_props = {"custom_field": "custom_value", "priority": "high"} + message = Message( + contents=[Content.from_text(text="Test message")], + role="assistant", + additional_properties=additional_props, + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.update_status.assert_called_once() + mock_updater.new_agent_message.assert_called_once() + call_args = mock_updater.new_agent_message.call_args + assert call_args.kwargs["metadata"] == additional_props + + async def test_handle_with_no_additional_properties(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test handling messages without additional properties.""" + # Arrange + message = Message( + contents=[Content.from_text(text="Test message")], + role="assistant", + additional_properties=None, + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.update_status.assert_called_once() + mock_updater.new_agent_message.assert_called_once() + call_args = mock_updater.new_agent_message.call_args + assert call_args.kwargs["metadata"] == {} + + async def test_parts_list_passed_to_new_agent_message(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test that parts list is correctly passed to new_agent_message.""" + # Arrange + message = Message( + contents=[ + Content.from_text(text="Message 1"), + Content.from_text(text="Message 2"), + ], + role="assistant", + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + mock_updater.new_agent_message.assert_called_once() + call_kwargs = mock_updater.new_agent_message.call_args.kwargs + assert "parts" in call_kwargs + parts_list = call_kwargs["parts"] + assert len(parts_list) == 2 + + async def test_task_state_always_working(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test that task state is always set to working.""" + # Arrange + message = Message( + contents=[Content.from_text(text="Any message")], + role="assistant", + ) + + # Act + await executor.handle_events(message, mock_updater) + + # Assert + call_kwargs = mock_updater.update_status.call_args.kwargs + assert call_kwargs["state"] == TaskState.working + + async def test_handle_agent_response_update_no_streamed_set( + self, executor: A2AExecutor, mock_updater: MagicMock + ) -> None: + """Test handling AgentResponseUpdate (streaming) without a tracking set.""" + # Arrange + update = AgentResponseUpdate( + contents=[Content.from_text(text="Streaming chunk")], + role="assistant", + message_id="msg-1", + ) + mock_updater.add_artifact = AsyncMock() + + # Act + await executor.handle_events(update, mock_updater) + + # Assert + mock_updater.add_artifact.assert_called_once() + call_kwargs = mock_updater.add_artifact.call_args.kwargs + assert call_kwargs["artifact_id"] == "msg-1" + assert call_kwargs["append"] is None + + async def test_handle_agent_response_update_first_time( + self, executor: A2AExecutor, mock_updater: MagicMock + ) -> None: + """Test handling AgentResponseUpdate (streaming) for the first time with a tracking set.""" + # Arrange + update = AgentResponseUpdate( + contents=[Content.from_text(text="Streaming chunk")], + role="assistant", + message_id="msg-1", + ) + mock_updater.add_artifact = AsyncMock() + streamed_artifact_ids = set() + + # Act + await executor.handle_events(update, mock_updater, streamed_artifact_ids=streamed_artifact_ids) + + # Assert + mock_updater.add_artifact.assert_called_once() + call_kwargs = mock_updater.add_artifact.call_args.kwargs + assert call_kwargs["append"] is None + assert "msg-1" in streamed_artifact_ids + + async def test_handle_agent_response_update_subsequent_time( + self, executor: A2AExecutor, mock_updater: MagicMock + ) -> None: + """Test handling AgentResponseUpdate (streaming) for subsequent times with a tracking set.""" + # Arrange + update = AgentResponseUpdate( + contents=[Content.from_text(text="Next chunk")], + role="assistant", + message_id="msg-1", + ) + mock_updater.add_artifact = AsyncMock() + streamed_artifact_ids = {"msg-1"} + + # Act + await executor.handle_events(update, mock_updater, streamed_artifact_ids=streamed_artifact_ids) + + # Assert + mock_updater.add_artifact.assert_called_once() + call_kwargs = mock_updater.add_artifact.call_args.kwargs + assert call_kwargs["append"] is True + + async def test_handle_unsupported_content_type(self, executor: A2AExecutor, mock_updater: MagicMock) -> None: + """Test handling messages with unsupported content types.""" + # Arrange + message = Message( + contents=[Content(type="unknown", text="Some text")], + role="assistant", + ) + + # Act + with patch("agent_framework_a2a._a2a_executor.logger") as mock_logger: + await executor.handle_events(message, mock_updater) + + # Assert + mock_logger.warning.assert_called_once() + mock_updater.update_status.assert_not_called() + + +class TestA2AExecutorIntegration: + """Integration tests for A2AExecutor.""" + + async def test_full_execution_flow_with_responses( + self, + executor: A2AExecutor, + mock_request_context: MagicMock, + mock_event_queue: MagicMock, + mock_task: Task, + ) -> None: + """Arrange: Create executor with all mocked dependencies + Act: Execute full flow from request to completion + Assert: All components interact correctly + """ + # Arrange + mock_request_context.get_user_input = MagicMock(return_value="Hello agent") + mock_request_context.current_task = mock_task + mock_request_context.context_id = "ctx-123" + mock_request_context.message = MagicMock() + + response = MagicMock(spec=AgentResponse) + response_message = MagicMock(spec=Message) + response.messages = [response_message] + response_message.contents = [Content.from_text(text="Hello user")] + response_message.role = "assistant" + response_message.additional_properties = None + + executor._agent.run = AsyncMock(return_value=response) + executor._agent.create_session = MagicMock() + executor.handle_events = AsyncMock() + + with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: + mock_updater = MagicMock() + mock_updater.submit = AsyncMock() + mock_updater.start_work = AsyncMock() + mock_updater.complete = AsyncMock() + mock_updater.update_status = AsyncMock() + mock_updater_class.return_value = mock_updater + + # Act + await executor.execute(mock_request_context, mock_event_queue) + + # Assert + mock_updater.submit.assert_called_once() + mock_updater.start_work.assert_called_once() + executor.handle_events.assert_called_once() + mock_updater.complete.assert_called_once() diff --git a/python/packages/a2a/tests/test_utils.py b/python/packages/a2a/tests/test_utils.py new file mode 100644 index 0000000000..2c73b2e7cf --- /dev/null +++ b/python/packages/a2a/tests/test_utils.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft. All rights reserved. + +import pytest + +from agent_framework_a2a._utils import get_uri_data + + +def test_get_uri_data_valid() -> None: + """Test get_uri_data with valid data URIs.""" + # Simple text/plain + uri = "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ==" + assert get_uri_data(uri) == "SGVsbG8sIFdvcmxkIQ==" + + # Image png + uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + assert get_uri_data(uri) == "iVBORw0KGgoAAAANSUhEUgfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + + # Application octet-stream + uri = "data:application/octet-stream;base64,AQIDBA==" + assert get_uri_data(uri) == "AQIDBA==" + + +def test_get_uri_data_invalid_format() -> None: + """Test get_uri_data with invalid URI formats.""" + invalid_uris = [ + "not-a-uri", + "http://example.com", + "data:text/plain;SGVsbG8sIFdvcmxkIQ==", # Missing base64 marker + "data:base64,SGVsbG8sIFdvcmxkIQ==", # Missing media type + "data:text/plain;charset=utf-8;base64,SGVsbG8sIFdvcmxkIQ==", # Extra parameters (current regex doesn't support) + "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ== extra", + ] + for uri in invalid_uris: + with pytest.raises(ValueError, match="Invalid data URI format"): + get_uri_data(uri) + + +def test_get_uri_data_empty() -> None: + """Test get_uri_data with empty string.""" + with pytest.raises(ValueError, match="Invalid data URI format"): + get_uri_data("") diff --git a/python/packages/core/agent_framework/a2a/__init__.py b/python/packages/core/agent_framework/a2a/__init__.py index 7c7de63456..90daf5380a 100644 --- a/python/packages/core/agent_framework/a2a/__init__.py +++ b/python/packages/core/agent_framework/a2a/__init__.py @@ -7,6 +7,7 @@ Supported classes: - A2AAgent +- A2AExecutor """ import importlib @@ -14,7 +15,7 @@ IMPORT_PATH = "agent_framework_a2a" PACKAGE_NAME = "agent-framework-a2a" -_IMPORTS = ["A2AAgent"] +_IMPORTS = ["A2AAgent", "A2AExecutor"] def __getattr__(name: str) -> Any: diff --git a/python/packages/core/agent_framework/a2a/__init__.pyi b/python/packages/core/agent_framework/a2a/__init__.pyi index 5a54bb22a9..65aa8f1a37 100644 --- a/python/packages/core/agent_framework/a2a/__init__.pyi +++ b/python/packages/core/agent_framework/a2a/__init__.pyi @@ -1,9 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -from agent_framework_a2a import ( - A2AAgent, -) +from agent_framework_a2a import A2AAgent, A2AExecutor -__all__ = [ - "A2AAgent", -] +__all__ = ["A2AAgent", "A2AExecutor"] diff --git a/python/samples/04-hosting/a2a/README.md b/python/samples/04-hosting/a2a/README.md index f377eed8ba..aca25a4dab 100644 --- a/python/samples/04-hosting/a2a/README.md +++ b/python/samples/04-hosting/a2a/README.md @@ -12,6 +12,7 @@ The remaining files are supporting modules used by the server: | File | Description | |------|-------------| +| [`agent_framework_to_a2a.py`](agent_framework_to_a2a.py) | Exposes an agent_framework agent as an A2A-compliant server. Demonstrates how to wrap an agent_framework agent and expose it as an A2A service that other A2A clients can discover and communicate with. | | [`agent_definitions.py`](agent_definitions.py) | Agent and AgentCard factory definitions for invoice, policy, and logistics agents. | | [`agent_executor.py`](agent_executor.py) | Bridges the a2a-sdk `AgentExecutor` interface to Agent Framework agents. | | [`invoice_data.py`](invoice_data.py) | Mock invoice data and tool functions for the invoice agent. | @@ -60,6 +61,9 @@ In a separate terminal (from the same directory), point the client at a running ```powershell $env:A2A_AGENT_HOST = "http://localhost:5001/" uv run python agent_with_a2a.py + +# A2A server exposing an agent_framework agent +uv run python agent_framework_to_a2a.py ``` ### 3. Run the Function Tools Sample diff --git a/python/samples/04-hosting/a2a/agent_framework_to_a2a.py b/python/samples/04-hosting/a2a/agent_framework_to_a2a.py new file mode 100644 index 0000000000..0693b61b0a --- /dev/null +++ b/python/samples/04-hosting/a2a/agent_framework_to_a2a.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft. All rights reserved. + +import uvicorn +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, +) +from agent_framework import Agent +from agent_framework.a2a import A2AExecutor +from agent_framework.openai import OpenAIChatClient +from dotenv import load_dotenv + +load_dotenv() + +if __name__ == "__main__": + # --8<-- [start:AgentSkill] + flight_skill = AgentSkill( + id="Flight_Booking", + name="Flight Booking", + description="Search and book flights across Europe.", + tags=["flights", "travel", "europe"], + examples=[], + ) + hotel_skill = AgentSkill( + id="Hotel_Booking", + name="Hotel Booking", + description="Search and book hotels across Europe.", + tags=["hotels", "travel", "accommodation"], + examples=[], + ) + # --8<-- [end:AgentSkill] + + # --8<-- [start:AgentCard] + # This will be the public-facing agent card + public_agent_card = AgentCard( + name="Europe Travel Agent", + description="A helpful Europe Travel Agent that can help users search and book flights and hotels across Europe.", + url="http://localhost:9999/", + version="1.0.0", + defaultInputModes=["text"], + defaultOutputModes=["text"], + capabilities=AgentCapabilities(streaming=True), + skills=[flight_skill, hotel_skill], + ) + # --8<-- [end:AgentCard] + + agent = Agent( + client=OpenAIChatClient(), + name="Europe Travel Agent", + instructions="You are a helpful Europe Travel Agent. You can help users search and book flights and hotels across Europe." + ) + + request_handler = DefaultRequestHandler( + agent_executor=A2AExecutor(agent), + task_store=InMemoryTaskStore(), + ) + + server = A2AStarletteApplication( + agent_card=public_agent_card, + http_handler=request_handler, + ) + + server = server.build() + # print(schemas.get_schema(server.routes)) + + uvicorn.run(server, host="0.0.0.0", port=9999)