Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,18 @@ async def on_message_send(
consumer = EventConsumer(queue)
producer_task.add_done_callback(consumer.agent_task_callback)

interrupted = False
blocking = True # Default to blocking behavior
if params.configuration and params.configuration.blocking is False:
blocking = False

interrupted_or_non_blocking = False
try:
(
result,
interrupted,
) = await result_aggregator.consume_and_break_on_interrupt(consumer)
interrupted_or_non_blocking,
) = await result_aggregator.consume_and_break_on_interrupt(
consumer, blocking=blocking
)
if not result:
raise ServerError(error=InternalError())

Expand All @@ -300,7 +306,7 @@ async def on_message_send(
logger.error(f'Agent execution failed. Error: {e}')
raise
finally:
if interrupted:
if interrupted_or_non_blocking:
# TODO: Track this disconnected cleanup task.
asyncio.create_task( # noqa: RUF006
self._cleanup_producer(producer_task, task_id)
Expand Down
27 changes: 23 additions & 4 deletions src/a2a/server/tasks/result_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,19 @@ async def consume_all(
return await self.task_manager.get_task()

async def consume_and_break_on_interrupt(
self, consumer: EventConsumer
self, consumer: EventConsumer, blocking: bool = True
) -> tuple[Task | Message | None, bool]:
"""Processes the event stream until completion or an interruptable state is encountered.

Interruptable states currently include `TaskState.auth_required`.
If `blocking` is False, it returns after the first event that creates a Task or Message.
If `blocking` is True, it waits for completion unless an `auth_required`
state is encountered, which is always an interruption.
If interrupted, consumption continues in a background task.

Args:
consumer: The `EventConsumer` to read events from.
blocking: If `False`, the method returns as soon as a task/message
is available. If `True`, it waits for a terminal state.

Returns:
A tuple containing:
Expand All @@ -117,10 +121,15 @@ async def consume_and_break_on_interrupt(
self._message = event
return event, False
await self.task_manager.process(event)
if (

should_interrupt = False
is_auth_required = (
isinstance(event, Task | TaskStatusUpdateEvent)
and event.status.state == TaskState.auth_required
):
)

# Always interrupt on auth_required, as it needs external action.
if is_auth_required:
# auth-required is a special state: the message should be
# escalated back to the caller, but the agent is expected to
# continue producing events once the authorization is received
Expand All @@ -130,6 +139,16 @@ async def consume_and_break_on_interrupt(
logger.debug(
'Encountered an auth-required task: breaking synchronous message/send flow.'
)
should_interrupt = True
# For non-blocking calls, interrupt as soon as a task is available.
elif not blocking:
logger.debug(
'Non-blocking call: returning task after first event.'
)
should_interrupt = True

if should_interrupt:
# Continue consuming the rest of the events in the background.
# TODO: We should track all outstanding tasks to ensure they eventually complete.
asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006
interrupted = True
Expand Down
38 changes: 38 additions & 0 deletions tests/server/tasks/test_result_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,44 @@ async def raiser_gen_interrupt():
)
self.mock_task_manager.get_task.assert_not_called()

@patch('asyncio.create_task')
async def test_consume_and_break_non_blocking(
self, mock_create_task: MagicMock
):
"""Test that with blocking=False, the method returns after the first event."""
first_event = create_sample_task('non_blocking_task')
event_after = create_sample_message('should be consumed later')

async def mock_consume_generator():
yield first_event
yield event_after

self.mock_event_consumer.consume_all.return_value = (
mock_consume_generator()
)
# After processing `first_event`, the current result will be that task.
self.aggregator.task_manager.get_task.return_value = first_event

self.aggregator._continue_consuming = AsyncMock()
mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro)

(
result,
interrupted,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer, blocking=False
)

self.assertEqual(result, first_event)
self.assertTrue(interrupted)
self.mock_task_manager.process.assert_called_once_with(first_event)
mock_create_task.assert_called_once()
# The background task should be created with the remaining stream
self.aggregator._continue_consuming.assert_called_once()
self.assertIsInstance(
self.aggregator._continue_consuming.call_args[0][0], AsyncIterator
)

@patch('asyncio.create_task') # To verify _continue_consuming is called
async def test_continue_consuming_processes_remaining_events(
self, mock_create_task: MagicMock
Expand Down
Loading