Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from a2a.types import (
GetTaskPushNotificationConfigParams,
InternalError,
InvalidParamsError,
Message,
MessageSendConfiguration,
MessageSendParams,
Expand All @@ -38,6 +39,7 @@
TaskNotFoundError,
TaskPushNotificationConfig,
TaskQueryParams,
TaskState,
UnsupportedOperationError,
)
from a2a.utils.errors import ServerError
Expand All @@ -46,6 +48,12 @@

logger = logging.getLogger(__name__)

TERMINAL_TASK_STATES = {
TaskState.completed,
TaskState.canceled,
TaskState.failed,
TaskState.rejected,
}

@trace_class(kind=SpanKind.SERVER)
class DefaultRequestHandler(RequestHandler):
Expand Down Expand Up @@ -170,28 +178,35 @@
Starts the agent execution for the message and waits for the final
result (Task or Message).
"""
task_manager = TaskManager(
task_id=params.message.taskId,
context_id=params.message.contextId,
task_store=self.task_store,
initial_message=params.message,
)
task: Task | None = await task_manager.get_task()
if task:
if task.status.state in TERMINAL_TASK_STATES:
raise ServerError(
error=InvalidParamsError(
message=f'Task {task.id} is in terminal state: {task.status.state}'
)
)

task = task_manager.update_with_message(params.message, task)
if self.should_add_push_info(params):
assert isinstance(self._push_notifier, PushNotifier)
assert isinstance(
params.configuration, MessageSendConfiguration
)
assert isinstance(
params.configuration.pushNotificationConfig,
PushNotificationConfig,
)
await self._push_notifier.set_info(
task.id, params.configuration.pushNotificationConfig
)
request_context = await self._request_context_builder.build(

Check notice on line 209 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (273-302)
params=params,
task_id=task.id if task else None,
context_id=params.message.contextId,
Expand Down Expand Up @@ -255,30 +270,36 @@
Starts the agent execution and yields events as they are produced
by the agent.
"""
task_manager = TaskManager(
task_id=params.message.taskId,
context_id=params.message.contextId,
task_store=self.task_store,
initial_message=params.message,
)
task: Task | None = await task_manager.get_task()

if task:
task = task_manager.update_with_message(params.message, task)
if task.status.state in TERMINAL_TASK_STATES:
raise ServerError(
error=InvalidParamsError(
message=f'Task {task.id} is in terminal state: {task.status.state}'
)
)

task = task_manager.update_with_message(params.message, task)
if self.should_add_push_info(params):
assert isinstance(self._push_notifier, PushNotifier)
assert isinstance(
params.configuration, MessageSendConfiguration
)
assert isinstance(
params.configuration.pushNotificationConfig,
PushNotificationConfig,
)
await self._push_notifier.set_info(
task.id, params.configuration.pushNotificationConfig
)
else:

Check notice on line 302 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (181-209)
queue = EventQueue()
result_aggregator = ResultAggregator(task_manager)
request_context = await self._request_context_builder.build(
Expand Down Expand Up @@ -413,18 +434,25 @@
if not task:
raise ServerError(error=TaskNotFoundError())

if task.status.state in TERMINAL_TASK_STATES:
raise ServerError(
error=InvalidParamsError(
message=f'Task {task.id} is in terminal state: {task.status.state}'
)
)

task_manager = TaskManager(
task_id=task.id,
context_id=task.contextId,
task_store=self.task_store,
initial_message=None,
)

result_aggregator = ResultAggregator(task_manager)

queue = await self._queue_manager.tap(task.id)
if not queue:
raise ServerError(error=TaskNotFoundError())

Check notice on line 455 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (121-133)

consumer = EventConsumer(queue)
async for event in result_aggregator.consume_and_emit(consumer):
Expand Down
134 changes: 134 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from a2a.types import (
InternalError,
InvalidParamsError,
Message,
MessageSendConfiguration,
MessageSendParams,
Expand Down Expand Up @@ -1137,3 +1138,136 @@ async def consume_stream():

texts = [p.root.text for e in events for p in e.status.message.parts]
assert texts == ['Event 0', 'Event 1', 'Event 2']

TERMINAL_TASK_STATES = {
TaskState.completed,
TaskState.canceled,
TaskState.failed,
TaskState.rejected,
}

@pytest.mark.asyncio
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
async def test_on_message_send_task_in_terminal_state(terminal_state):
"""Test on_message_send when task is already in a terminal state."""
task_id = f'terminal_task_{terminal_state.value}'
terminal_task = create_sample_task(
task_id=task_id, status_state=terminal_state
)

mock_task_store = AsyncMock(spec=TaskStore)
# The get method of TaskManager calls task_store.get.
# We mock TaskManager.get_task which is an async method.
# So we should patch that instead.

request_handler = DefaultRequestHandler(
agent_executor=DummyAgentExecutor(), task_store=mock_task_store
)

params = MessageSendParams(
message=Message(
role=Role.user,
messageId='msg_terminal',
parts=[],
taskId=task_id,
)
)

from a2a.utils.errors import ServerError

# Patch the TaskManager's get_task method to return our terminal task
with patch(
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
return_value=terminal_task,
):
with pytest.raises(ServerError) as exc_info:
await request_handler.on_message_send(
params, create_server_call_context()
)

assert isinstance(exc_info.value.error, InvalidParamsError)
assert exc_info.value.error.message
assert (
f'Task {task_id} is in terminal state: {terminal_state.value}'
in exc_info.value.error.message
)


@pytest.mark.asyncio
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
async def test_on_message_send_stream_task_in_terminal_state(terminal_state):
"""Test on_message_send_stream when task is already in a terminal state."""
task_id = f'terminal_stream_task_{terminal_state.value}'
terminal_task = create_sample_task(
task_id=task_id, status_state=terminal_state
)

mock_task_store = AsyncMock(spec=TaskStore)

request_handler = DefaultRequestHandler(
agent_executor=DummyAgentExecutor(), task_store=mock_task_store
)

params = MessageSendParams(
message=Message(
role=Role.user,
messageId='msg_terminal_stream',
parts=[],
taskId=task_id,
)
)

from a2a.utils.errors import ServerError

with patch(
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
return_value=terminal_task,
):
with pytest.raises(ServerError) as exc_info:
async for _ in request_handler.on_message_send_stream(
params, create_server_call_context()
):
pass # pragma: no cover

assert isinstance(exc_info.value.error, InvalidParamsError)
assert exc_info.value.error.message
assert (
f'Task {task_id} is in terminal state: {terminal_state.value}'
in exc_info.value.error.message
)


@pytest.mark.asyncio
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
async def test_on_resubscribe_to_task_in_terminal_state(terminal_state):
"""Test on_resubscribe_to_task when task is in a terminal state."""
task_id = f'resub_terminal_task_{terminal_state.value}'
terminal_task = create_sample_task(
task_id=task_id, status_state=terminal_state
)

mock_task_store = AsyncMock(spec=TaskStore)
mock_task_store.get.return_value = terminal_task

request_handler = DefaultRequestHandler(
agent_executor=DummyAgentExecutor(),
task_store=mock_task_store,
queue_manager=AsyncMock(spec=QueueManager),
)
params = TaskIdParams(id=task_id)

from a2a.utils.errors import ServerError

with pytest.raises(ServerError) as exc_info:
async for _ in request_handler.on_resubscribe_to_task(
params, create_server_call_context()
):
pass # pragma: no cover

assert isinstance(exc_info.value.error, InvalidParamsError)
assert exc_info.value.error.message
assert (
f'Task {task_id} is in terminal state: {terminal_state.value}'
in exc_info.value.error.message
)
mock_task_store.get.assert_awaited_once_with(task_id)
Loading