diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index a07be3cf2f..696a160cf6 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -295,7 +295,10 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] else: if not normalized_messages: raise ValueError("At least one message is required when starting a new task (no continuation_token).") - a2a_message = self._prepare_message_for_a2a(normalized_messages[-1]) + a2a_message = self._prepare_message_for_a2a( + normalized_messages[-1], + context_id=session.service_session_id if session else None, + ) a2a_stream = self.client.send_message(a2a_message) provider_session = session @@ -584,7 +587,7 @@ async def poll_task(self, continuation_token: A2AContinuationToken) -> AgentResp return AgentResponse.from_updates(updates) return AgentResponse(messages=[], response_id=task.id, raw_representation=task) - def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: + def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None = None) -> A2AMessage: """Prepare a Message for the A2A protocol. Transforms Agent Framework Message objects into A2A protocol Messages by: @@ -593,6 +596,13 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: - Converting file references (URI/data/hosted_file) to FilePart objects - Preserving metadata and additional properties from the original message - Setting the role to 'user' as framework messages are treated as user input + + Args: + message: The framework Message to convert. + context_id: Optional fallback context identifier (e.g. derived from + ``AgentSession.service_session_id``). When the *message* already + carries a ``context_id`` in its ``additional_properties`` that + value takes precedence; otherwise this fallback is used. """ parts: list[A2APart] = [] if not message.contents: @@ -672,7 +682,7 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: role=A2ARole("user"), parts=parts, message_id=message.message_id or uuid.uuid4().hex, - context_id=message.additional_properties.get("context_id"), + context_id=message.additional_properties.get("context_id") or context_id, metadata=metadata, ) diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 484d71e22c..dbbad8a865 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -46,6 +46,7 @@ def __init__(self) -> None: self.responses: list[Any] = [] self.resubscribe_responses: list[Any] = [] self.get_task_response: Task | None = None + self.last_message: Any = None def add_message_response(self, message_id: str, text: str, role: str = "agent") -> None: """Add a mock Message response.""" @@ -111,6 +112,7 @@ def add_in_progress_task_response( async def send_message(self, message: Any) -> AsyncIterator[Any]: """Mock send_message method that yields responses.""" + self.last_message = message self.call_count += 1 # All queued responses are delivered as a single streaming batch per call. @@ -539,6 +541,37 @@ def test_prepare_message_for_a2a_forwards_context_id() -> None: assert result.metadata == {"trace_id": "trace-456"} +def test_prepare_message_for_a2a_uses_fallback_context_id() -> None: + """Test that context_id kwarg is used when message has no context_id property.""" + + agent = A2AAgent(client=MagicMock(), http_client=None) + + message = Message( + role="user", + contents=[Content.from_text(text="Hello")], + ) + + result = agent._prepare_message_for_a2a(message, context_id="session-ctx-1") + + assert result.context_id == "session-ctx-1" + + +def test_prepare_message_for_a2a_message_context_id_takes_precedence() -> None: + """Test that message.additional_properties context_id wins over the fallback.""" + + agent = A2AAgent(client=MagicMock(), http_client=None) + + message = Message( + role="user", + contents=[Content.from_text(text="Hello")], + additional_properties={"context_id": "explicit-ctx"}, + ) + + result = agent._prepare_message_for_a2a(message, context_id="session-ctx-1") + + assert result.context_id == "explicit-ctx" + + def test_parse_contents_from_a2a_with_data_part() -> None: """Test conversion of A2A DataPart.""" @@ -868,6 +901,43 @@ async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2A # endregion +# region Session context_id Integration Tests + + +@mark.asyncio +async def test_run_passes_session_service_session_id_as_context_id(mock_a2a_client: MockA2AClient) -> None: + """Test that run() wires session.service_session_id to the A2A message context_id.""" + agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + mock_a2a_client.add_message_response("msg-ctx", "reply") + + session = AgentSession(service_session_id="svc-session-42") + await agent.run("Hello", session=session) + + assert mock_a2a_client.last_message is not None + assert mock_a2a_client.last_message.context_id == "svc-session-42" + + +@mark.asyncio +async def test_run_message_context_id_takes_precedence_over_session(mock_a2a_client: MockA2AClient) -> None: + """Test that an explicit context_id on the message wins over session.service_session_id.""" + agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + mock_a2a_client.add_message_response("msg-ctx2", "reply") + + session = AgentSession(service_session_id="svc-session-42") + message = Message( + role="user", + contents=[Content.from_text(text="Hello")], + additional_properties={"context_id": "explicit-ctx"}, + ) + await agent.run(messages=[message], session=session) + + assert mock_a2a_client.last_message is not None + assert mock_a2a_client.last_message.context_id == "explicit-ctx" + + +# endregion + + # region Context Provider Tests diff --git a/python/samples/02-agents/conversations/file_history_provider.py b/python/samples/02-agents/conversations/file_history_provider.py index 04a87f8224..20735ffd17 100644 --- a/python/samples/02-agents/conversations/file_history_provider.py +++ b/python/samples/02-agents/conversations/file_history_provider.py @@ -21,7 +21,7 @@ from pydantic import Field try: - import orjson + import orjson # pyright: ignore[reportMissingImports] except ImportError: orjson = None diff --git a/python/samples/02-agents/conversations/file_history_provider_conversation_persistence.py b/python/samples/02-agents/conversations/file_history_provider_conversation_persistence.py index 70c5d7e8e8..693501b0f9 100644 --- a/python/samples/02-agents/conversations/file_history_provider_conversation_persistence.py +++ b/python/samples/02-agents/conversations/file_history_provider_conversation_persistence.py @@ -22,7 +22,7 @@ from pydantic import Field try: - import orjson + import orjson # pyright: ignore[reportMissingImports] except ImportError: orjson = None