diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index cc5d921a..f98d06c0 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -29,6 +29,7 @@ from a2a.types import ( GetTaskPushNotificationConfigParams, InternalError, + InvalidParamsError, Message, MessageSendConfiguration, MessageSendParams, @@ -38,6 +39,7 @@ TaskNotFoundError, TaskPushNotificationConfig, TaskQueryParams, + TaskState, UnsupportedOperationError, ) from a2a.utils.errors import ServerError @@ -46,6 +48,12 @@ logger = logging.getLogger(__name__) +TERMINAL_TASK_STATES = { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, +} @trace_class(kind=SpanKind.SERVER) class DefaultRequestHandler(RequestHandler): @@ -178,6 +186,13 @@ async def on_message_send( ) task: Task | None = await task_manager.get_task() if task: + if task.status.state in TERMINAL_TASK_STATES: + raise ServerError( + error=InvalidParamsError( + message=f'Task {task.id} is in terminal state: {task.status.state}' + ) + ) + task = task_manager.update_with_message(params.message, task) if self.should_add_push_info(params): assert isinstance(self._push_notifier, PushNotifier) @@ -264,8 +279,14 @@ async def on_message_send_stream( task: Task | None = await task_manager.get_task() if task: - task = task_manager.update_with_message(params.message, task) + if task.status.state in TERMINAL_TASK_STATES: + raise ServerError( + error=InvalidParamsError( + message=f'Task {task.id} is in terminal state: {task.status.state}' + ) + ) + task = task_manager.update_with_message(params.message, task) if self.should_add_push_info(params): assert isinstance(self._push_notifier, PushNotifier) assert isinstance( @@ -413,6 +434,13 @@ async def on_resubscribe_to_task( if not task: raise ServerError(error=TaskNotFoundError()) + if task.status.state in TERMINAL_TASK_STATES: + raise ServerError( + error=InvalidParamsError( + message=f'Task {task.id} is in terminal state: {task.status.state}' + ) + ) + task_manager = TaskManager( task_id=task.id, context_id=task.contextId, diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index e30b843c..6f67c0f8 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -28,6 +28,7 @@ ) from a2a.types import ( InternalError, + InvalidParamsError, Message, MessageSendConfiguration, MessageSendParams, @@ -1137,3 +1138,136 @@ async def consume_stream(): texts = [p.root.text for e in events for p in e.status.message.parts] assert texts == ['Event 0', 'Event 1', 'Event 2'] + +TERMINAL_TASK_STATES = { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, +} + +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) +async def test_on_message_send_task_in_terminal_state(terminal_state): + """Test on_message_send when task is already in a terminal state.""" + task_id = f'terminal_task_{terminal_state.value}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + + mock_task_store = AsyncMock(spec=TaskStore) + # The get method of TaskManager calls task_store.get. + # We mock TaskManager.get_task which is an async method. + # So we should patch that instead. + + request_handler = DefaultRequestHandler( + agent_executor=DummyAgentExecutor(), task_store=mock_task_store + ) + + params = MessageSendParams( + message=Message( + role=Role.user, + messageId='msg_terminal', + parts=[], + taskId=task_id, + ) + ) + + from a2a.utils.errors import ServerError + + # Patch the TaskManager's get_task method to return our terminal task + with patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=terminal_task, + ): + with pytest.raises(ServerError) as exc_info: + await request_handler.on_message_send( + params, create_server_call_context() + ) + + assert isinstance(exc_info.value.error, InvalidParamsError) + assert exc_info.value.error.message + assert ( + f'Task {task_id} is in terminal state: {terminal_state.value}' + in exc_info.value.error.message + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) +async def test_on_message_send_stream_task_in_terminal_state(terminal_state): + """Test on_message_send_stream when task is already in a terminal state.""" + task_id = f'terminal_stream_task_{terminal_state.value}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + + mock_task_store = AsyncMock(spec=TaskStore) + + request_handler = DefaultRequestHandler( + agent_executor=DummyAgentExecutor(), task_store=mock_task_store + ) + + params = MessageSendParams( + message=Message( + role=Role.user, + messageId='msg_terminal_stream', + parts=[], + taskId=task_id, + ) + ) + + from a2a.utils.errors import ServerError + + with patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=terminal_task, + ): + with pytest.raises(ServerError) as exc_info: + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass # pragma: no cover + + assert isinstance(exc_info.value.error, InvalidParamsError) + assert exc_info.value.error.message + assert ( + f'Task {task_id} is in terminal state: {terminal_state.value}' + in exc_info.value.error.message + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) +async def test_on_resubscribe_to_task_in_terminal_state(terminal_state): + """Test on_resubscribe_to_task when task is in a terminal state.""" + task_id = f'resub_terminal_task_{terminal_state.value}' + terminal_task = create_sample_task( + task_id=task_id, status_state=terminal_state + ) + + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = terminal_task + + request_handler = DefaultRequestHandler( + agent_executor=DummyAgentExecutor(), + task_store=mock_task_store, + queue_manager=AsyncMock(spec=QueueManager), + ) + params = TaskIdParams(id=task_id) + + from a2a.utils.errors import ServerError + + with pytest.raises(ServerError) as exc_info: + async for _ in request_handler.on_resubscribe_to_task( + params, create_server_call_context() + ): + pass # pragma: no cover + + assert isinstance(exc_info.value.error, InvalidParamsError) + assert exc_info.value.error.message + assert ( + f'Task {task_id} is in terminal state: {terminal_state.value}' + in exc_info.value.error.message + ) + mock_task_store.get.assert_awaited_once_with(task_id)