Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,10 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
else:
if not normalized_messages:
raise ValueError("At least one message is required when starting a new task (no continuation_token).")
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1])
a2a_message = self._prepare_message_for_a2a(
normalized_messages[-1],
context_id=session.service_session_id if session else None,
)
Comment thread
moonbox3 marked this conversation as resolved.
a2a_stream = self.client.send_message(a2a_message)

provider_session = session
Expand Down Expand Up @@ -584,7 +587,7 @@ async def poll_task(self, continuation_token: A2AContinuationToken) -> AgentResp
return AgentResponse.from_updates(updates)
return AgentResponse(messages=[], response_id=task.id, raw_representation=task)

def _prepare_message_for_a2a(self, message: Message) -> A2AMessage:
def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None = None) -> A2AMessage:
"""Prepare a Message for the A2A protocol.

Transforms Agent Framework Message objects into A2A protocol Messages by:
Expand All @@ -593,6 +596,13 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage:
- Converting file references (URI/data/hosted_file) to FilePart objects
- Preserving metadata and additional properties from the original message
- Setting the role to 'user' as framework messages are treated as user input

Args:
message: The framework Message to convert.
context_id: Optional fallback context identifier (e.g. derived from
``AgentSession.service_session_id``). When the *message* already
carries a ``context_id`` in its ``additional_properties`` that
value takes precedence; otherwise this fallback is used.
"""
parts: list[A2APart] = []
if not message.contents:
Expand Down Expand Up @@ -672,7 +682,7 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage:
role=A2ARole("user"),
parts=parts,
message_id=message.message_id or uuid.uuid4().hex,
context_id=message.additional_properties.get("context_id"),
context_id=message.additional_properties.get("context_id") or context_id,
metadata=metadata,
)

Expand Down
70 changes: 70 additions & 0 deletions python/packages/a2a/tests/test_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self) -> None:
self.responses: list[Any] = []
self.resubscribe_responses: list[Any] = []
self.get_task_response: Task | None = None
self.last_message: Any = None

def add_message_response(self, message_id: str, text: str, role: str = "agent") -> None:
"""Add a mock Message response."""
Expand Down Expand Up @@ -111,6 +112,7 @@ def add_in_progress_task_response(

async def send_message(self, message: Any) -> AsyncIterator[Any]:
"""Mock send_message method that yields responses."""
self.last_message = message
self.call_count += 1

# All queued responses are delivered as a single streaming batch per call.
Expand Down Expand Up @@ -539,6 +541,37 @@ def test_prepare_message_for_a2a_forwards_context_id() -> None:
assert result.metadata == {"trace_id": "trace-456"}


def test_prepare_message_for_a2a_uses_fallback_context_id() -> None:
"""Test that context_id kwarg is used when message has no context_id property."""

agent = A2AAgent(client=MagicMock(), http_client=None)

message = Message(
role="user",
contents=[Content.from_text(text="Hello")],
)

result = agent._prepare_message_for_a2a(message, context_id="session-ctx-1")

assert result.context_id == "session-ctx-1"


def test_prepare_message_for_a2a_message_context_id_takes_precedence() -> None:
"""Test that message.additional_properties context_id wins over the fallback."""

agent = A2AAgent(client=MagicMock(), http_client=None)

message = Message(
role="user",
contents=[Content.from_text(text="Hello")],
additional_properties={"context_id": "explicit-ctx"},
)

result = agent._prepare_message_for_a2a(message, context_id="session-ctx-1")

assert result.context_id == "explicit-ctx"


def test_parse_contents_from_a2a_with_data_part() -> None:
"""Test conversion of A2A DataPart."""

Expand Down Expand Up @@ -868,6 +901,43 @@ async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2A
# endregion


# region Session context_id Integration Tests


@mark.asyncio
async def test_run_passes_session_service_session_id_as_context_id(mock_a2a_client: MockA2AClient) -> None:
"""Test that run() wires session.service_session_id to the A2A message context_id."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_message_response("msg-ctx", "reply")

session = AgentSession(service_session_id="svc-session-42")
await agent.run("Hello", session=session)

assert mock_a2a_client.last_message is not None
assert mock_a2a_client.last_message.context_id == "svc-session-42"


@mark.asyncio
async def test_run_message_context_id_takes_precedence_over_session(mock_a2a_client: MockA2AClient) -> None:
"""Test that an explicit context_id on the message wins over session.service_session_id."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_message_response("msg-ctx2", "reply")

session = AgentSession(service_session_id="svc-session-42")
message = Message(
role="user",
contents=[Content.from_text(text="Hello")],
additional_properties={"context_id": "explicit-ctx"},
)
await agent.run(messages=[message], session=session)

assert mock_a2a_client.last_message is not None
assert mock_a2a_client.last_message.context_id == "explicit-ctx"


# endregion


# region Context Provider Tests


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import Field

try:
import orjson
import orjson # pyright: ignore[reportMissingImports]
except ImportError:
orjson = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pydantic import Field

try:
import orjson
import orjson # pyright: ignore[reportMissingImports]
except ImportError:
orjson = None

Expand Down
Loading