From 7f8a6badaac4c491483e713bc905396244b43eed Mon Sep 17 00:00:00 2001 From: Swapnil agarwal Date: Wed, 18 Jun 2025 15:26:47 -0700 Subject: [PATCH 1/2] feat: raise error for tasks in terminal states --- .../default_request_handler.py | 39 ++++- .../test_default_request_handler.py | 136 ++++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index cc5d921a..23232b4b 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -29,6 +29,7 @@ from a2a.types import ( GetTaskPushNotificationConfigParams, InternalError, + InvalidParamsError, Message, MessageSendConfiguration, MessageSendParams, @@ -38,6 +39,7 @@ TaskNotFoundError, TaskPushNotificationConfig, TaskQueryParams, + TaskState, UnsupportedOperationError, ) from a2a.utils.errors import ServerError @@ -178,6 +180,18 @@ async def on_message_send( ) task: Task | None = await task_manager.get_task() if task: + if task.status.state in { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, + }: + raise ServerError( + error=InvalidParamsError( + message=f'Task {task.id} is in terminal state: {task.status.state}' + ) + ) + task = task_manager.update_with_message(params.message, task) if self.should_add_push_info(params): assert isinstance(self._push_notifier, PushNotifier) @@ -264,8 +278,19 @@ async def on_message_send_stream( task: Task | None = await task_manager.get_task() if task: - task = task_manager.update_with_message(params.message, task) + if task.status.state in { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, + }: + raise ServerError( + error=InvalidParamsError( + message=f'Task {task.id} is in terminal state: {task.status.state}' + ) + ) + task = task_manager.update_with_message(params.message, task) if self.should_add_push_info(params): assert isinstance(self._push_notifier, PushNotifier) assert isinstance( @@ -413,6 +438,18 @@ async def on_resubscribe_to_task( if not task: raise ServerError(error=TaskNotFoundError()) + if task.status.state in { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, + }: + raise ServerError( + error=InvalidParamsError( + message=f'Task {task.id} is in terminal state: {task.status.state}' + ) + ) + task_manager = TaskManager( task_id=task.id, context_id=task.contextId, diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index e30b843c..2a89f3d6 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -28,6 +28,7 @@ ) from a2a.types import ( InternalError, + InvalidParamsError, Message, MessageSendConfiguration, MessageSendParams, @@ -1137,3 +1138,138 @@ async def consume_stream(): texts = [p.root.text for e in events for p in e.status.message.parts] assert texts == ['Event 0', 'Event 1', 'Event 2'] + + +TERMINAL_STATES = [ + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_STATES) +async def test_on_message_send_task_in_terminal_state(terminal_state): + """Test on_message_send when task is already in a terminal state.""" + task_id = f'terminal_task_{terminal_state.value}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + + mock_task_store = AsyncMock(spec=TaskStore) + # The get method of TaskManager calls task_store.get. + # We mock TaskManager.get_task which is an async method. + # So we should patch that instead. + + request_handler = DefaultRequestHandler( + agent_executor=DummyAgentExecutor(), task_store=mock_task_store + ) + + params = MessageSendParams( + message=Message( + role=Role.user, + messageId='msg_terminal', + parts=[], + taskId=task_id, + ) + ) + + from a2a.utils.errors import ServerError + + # Patch the TaskManager's get_task method to return our terminal task + with patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=terminal_task, + ): + with pytest.raises(ServerError) as exc_info: + await request_handler.on_message_send( + params, create_server_call_context() + ) + + assert isinstance(exc_info.value.error, InvalidParamsError) + assert exc_info.value.error.message + assert ( + f'Task {task_id} is in terminal state: {terminal_state.value}' + in exc_info.value.error.message + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_STATES) +async def test_on_message_send_stream_task_in_terminal_state(terminal_state): + """Test on_message_send_stream when task is already in a terminal state.""" + task_id = f'terminal_stream_task_{terminal_state.value}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + + mock_task_store = AsyncMock(spec=TaskStore) + + request_handler = DefaultRequestHandler( + agent_executor=DummyAgentExecutor(), task_store=mock_task_store + ) + + params = MessageSendParams( + message=Message( + role=Role.user, + messageId='msg_terminal_stream', + parts=[], + taskId=task_id, + ) + ) + + from a2a.utils.errors import ServerError + + with patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=terminal_task, + ): + with pytest.raises(ServerError) as exc_info: + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass # pragma: no cover + + assert isinstance(exc_info.value.error, InvalidParamsError) + assert exc_info.value.error.message + assert ( + f'Task {task_id} is in terminal state: {terminal_state.value}' + in exc_info.value.error.message + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_STATES) +async def test_on_resubscribe_to_task_in_terminal_state(terminal_state): + """Test on_resubscribe_to_task when task is in a terminal state.""" + task_id = f'resub_terminal_task_{terminal_state.value}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = terminal_task + + request_handler = DefaultRequestHandler( + agent_executor=DummyAgentExecutor(), + task_store=mock_task_store, + queue_manager=AsyncMock(spec=QueueManager), + ) + params = TaskIdParams(id=task_id) + + from a2a.utils.errors import ServerError + + with pytest.raises(ServerError) as exc_info: + async for _ in request_handler.on_resubscribe_to_task( + params, create_server_call_context() + ): + pass # pragma: no cover + + assert isinstance(exc_info.value.error, InvalidParamsError) + assert exc_info.value.error.message + assert ( + f'Task {task_id} is in terminal state: {terminal_state.value}' + in exc_info.value.error.message + ) + mock_task_store.get.assert_awaited_once_with(task_id) From 6e3451dd00046036e79f6fda7374ca77867a9fc3 Mon Sep 17 00:00:00 2001 From: Swapnil agarwal Date: Wed, 18 Jun 2025 15:33:12 -0700 Subject: [PATCH 2/2] define terminal states in a constant --- .../default_request_handler.py | 27 +++++++------------ .../test_default_request_handler.py | 12 ++++----- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 23232b4b..f98d06c0 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -48,6 +48,12 @@ logger = logging.getLogger(__name__) +TERMINAL_TASK_STATES = { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, +} @trace_class(kind=SpanKind.SERVER) class DefaultRequestHandler(RequestHandler): @@ -180,12 +186,7 @@ async def on_message_send( ) task: Task | None = await task_manager.get_task() if task: - if task.status.state in { - TaskState.completed, - TaskState.canceled, - TaskState.failed, - TaskState.rejected, - }: + if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=InvalidParamsError( message=f'Task {task.id} is in terminal state: {task.status.state}' @@ -278,12 +279,7 @@ async def on_message_send_stream( task: Task | None = await task_manager.get_task() if task: - if task.status.state in { - TaskState.completed, - TaskState.canceled, - TaskState.failed, - TaskState.rejected, - }: + if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=InvalidParamsError( message=f'Task {task.id} is in terminal state: {task.status.state}' @@ -438,12 +434,7 @@ async def on_resubscribe_to_task( if not task: raise ServerError(error=TaskNotFoundError()) - if task.status.state in { - TaskState.completed, - TaskState.canceled, - TaskState.failed, - TaskState.rejected, - }: + if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=InvalidParamsError( message=f'Task {task.id} is in terminal state: {task.status.state}' diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 2a89f3d6..6f67c0f8 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -1139,17 +1139,15 @@ async def consume_stream(): texts = [p.root.text for e in events for p in e.status.message.parts] assert texts == ['Event 0', 'Event 1', 'Event 2'] - -TERMINAL_STATES = [ +TERMINAL_TASK_STATES = { TaskState.completed, TaskState.canceled, TaskState.failed, TaskState.rejected, -] - +} @pytest.mark.asyncio -@pytest.mark.parametrize('terminal_state', TERMINAL_STATES) +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) async def test_on_message_send_task_in_terminal_state(terminal_state): """Test on_message_send when task is already in a terminal state.""" task_id = f'terminal_task_{terminal_state.value}' @@ -1196,7 +1194,7 @@ async def test_on_message_send_task_in_terminal_state(terminal_state): @pytest.mark.asyncio -@pytest.mark.parametrize('terminal_state', TERMINAL_STATES) +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) async def test_on_message_send_stream_task_in_terminal_state(terminal_state): """Test on_message_send_stream when task is already in a terminal state.""" task_id = f'terminal_stream_task_{terminal_state.value}' @@ -1240,7 +1238,7 @@ async def test_on_message_send_stream_task_in_terminal_state(terminal_state): @pytest.mark.asyncio -@pytest.mark.parametrize('terminal_state', TERMINAL_STATES) +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) async def test_on_resubscribe_to_task_in_terminal_state(terminal_state): """Test on_resubscribe_to_task when task is in a terminal state.""" task_id = f'resub_terminal_task_{terminal_state.value}'