diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 8c5f58a3..2549d087 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -280,12 +280,18 @@ async def on_message_send( consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) - interrupted = False + blocking = True # Default to blocking behavior + if params.configuration and params.configuration.blocking is False: + blocking = False + + interrupted_or_non_blocking = False try: ( result, - interrupted, - ) = await result_aggregator.consume_and_break_on_interrupt(consumer) + interrupted_or_non_blocking, + ) = await result_aggregator.consume_and_break_on_interrupt( + consumer, blocking=blocking + ) if not result: raise ServerError(error=InternalError()) @@ -300,7 +306,7 @@ async def on_message_send( logger.error(f'Agent execution failed. Error: {e}') raise finally: - if interrupted: + if interrupted_or_non_blocking: # TODO: Track this disconnected cleanup task. asyncio.create_task( # noqa: RUF006 self._cleanup_producer(producer_task, task_id) diff --git a/src/a2a/server/tasks/result_aggregator.py b/src/a2a/server/tasks/result_aggregator.py index a11f0222..147c3202 100644 --- a/src/a2a/server/tasks/result_aggregator.py +++ b/src/a2a/server/tasks/result_aggregator.py @@ -92,15 +92,19 @@ async def consume_all( return await self.task_manager.get_task() async def consume_and_break_on_interrupt( - self, consumer: EventConsumer + self, consumer: EventConsumer, blocking: bool = True ) -> tuple[Task | Message | None, bool]: """Processes the event stream until completion or an interruptable state is encountered. - Interruptable states currently include `TaskState.auth_required`. + If `blocking` is False, it returns after the first event that creates a Task or Message. + If `blocking` is True, it waits for completion unless an `auth_required` + state is encountered, which is always an interruption. If interrupted, consumption continues in a background task. Args: consumer: The `EventConsumer` to read events from. + blocking: If `False`, the method returns as soon as a task/message + is available. If `True`, it waits for a terminal state. Returns: A tuple containing: @@ -117,10 +121,15 @@ async def consume_and_break_on_interrupt( self._message = event return event, False await self.task_manager.process(event) - if ( + + should_interrupt = False + is_auth_required = ( isinstance(event, Task | TaskStatusUpdateEvent) and event.status.state == TaskState.auth_required - ): + ) + + # Always interrupt on auth_required, as it needs external action. + if is_auth_required: # auth-required is a special state: the message should be # escalated back to the caller, but the agent is expected to # continue producing events once the authorization is received @@ -130,6 +139,16 @@ async def consume_and_break_on_interrupt( logger.debug( 'Encountered an auth-required task: breaking synchronous message/send flow.' ) + should_interrupt = True + # For non-blocking calls, interrupt as soon as a task is available. + elif not blocking: + logger.debug( + 'Non-blocking call: returning task after first event.' + ) + should_interrupt = True + + if should_interrupt: + # Continue consuming the rest of the events in the background. # TODO: We should track all outstanding tasks to ensure they eventually complete. asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006 interrupted = True diff --git a/tests/server/tasks/test_result_aggregator.py b/tests/server/tasks/test_result_aggregator.py index 7ff5915e..da77e693 100644 --- a/tests/server/tasks/test_result_aggregator.py +++ b/tests/server/tasks/test_result_aggregator.py @@ -384,6 +384,44 @@ async def raiser_gen_interrupt(): ) self.mock_task_manager.get_task.assert_not_called() + @patch('asyncio.create_task') + async def test_consume_and_break_non_blocking( + self, mock_create_task: MagicMock + ): + """Test that with blocking=False, the method returns after the first event.""" + first_event = create_sample_task('non_blocking_task') + event_after = create_sample_message('should be consumed later') + + async def mock_consume_generator(): + yield first_event + yield event_after + + self.mock_event_consumer.consume_all.return_value = ( + mock_consume_generator() + ) + # After processing `first_event`, the current result will be that task. + self.aggregator.task_manager.get_task.return_value = first_event + + self.aggregator._continue_consuming = AsyncMock() + mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro) + + ( + result, + interrupted, + ) = await self.aggregator.consume_and_break_on_interrupt( + self.mock_event_consumer, blocking=False + ) + + self.assertEqual(result, first_event) + self.assertTrue(interrupted) + self.mock_task_manager.process.assert_called_once_with(first_event) + mock_create_task.assert_called_once() + # The background task should be created with the remaining stream + self.aggregator._continue_consuming.assert_called_once() + self.assertIsInstance( + self.aggregator._continue_consuming.call_args[0][0], AsyncIterator + ) + @patch('asyncio.create_task') # To verify _continue_consuming is called async def test_continue_consuming_processes_remaining_events( self, mock_create_task: MagicMock