Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/a2a/_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from typing import Any, ClassVar

from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -27,7 +29,8 @@ class A2ABaseModel(BaseModel):
serves as the foundation for future extensions or shared utilities.

This implementation provides backward compatibility for camelCase aliases
by lazy-loading an alias map upon first use.
by lazy-loading an alias map upon first use. Accessing or setting
attributes via their camelCase alias will raise a DeprecationWarning.
"""

model_config = ConfigDict(
Expand Down Expand Up @@ -60,14 +63,37 @@ def __setattr__(self, name: str, value: Any) -> None:
"""Allow setting attributes via their camelCase alias."""
# Get the map and find the corresponding snake_case field name.
field_name = type(self)._get_alias_map().get(name) # noqa: SLF001

if field_name:
# An alias was used, issue a warning.
warnings.warn(
(
f"Setting field '{name}' via its camelCase alias is deprecated and will be removed in version 0.3.0 "
f"Use the snake_case name '{field_name}' instead."
),
DeprecationWarning,
stacklevel=2,
)

# If an alias was used, field_name will be set; otherwise, use the original name.
super().__setattr__(field_name or name, value)

def __getattr__(self, name: str) -> Any:
"""Allow getting attributes via their camelCase alias."""
# Get the map and find the corresponding snake_case field name.
field_name = type(self)._get_alias_map().get(name) # noqa: SLF001

if field_name:
# An alias was used, issue a warning.
warnings.warn(
(
f"Accessing field '{name}' via its camelCase alias is deprecated and will be removed in version 0.3.0 "
f"Use the snake_case name '{field_name}' instead."
),
DeprecationWarning,
stacklevel=2,
)

# If an alias was used, retrieve the actual snake_case attribute.
return getattr(self, field_name)

Expand Down
4 changes: 2 additions & 2 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,10 @@ async def _setup_message_execution(
)

task = task_manager.update_with_message(params.message, task)
elif params.message.taskId:
elif params.message.task_id:
raise ServerError(
error=TaskNotFoundError(
message=f'Task {params.message.taskId} was specified but does not exist'
message=f'Task {params.message.task_id} was specified but does not exist'
)
)

Expand Down
16 changes: 8 additions & 8 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,7 +1787,7 @@ async def test_on_resubscribe_to_task_in_terminal_state(terminal_state):

@pytest.mark.asyncio
async def test_on_message_send_task_id_provided_but_task_not_found():
"""Test on_message_send when taskId is provided but task doesn't exist."""
"""Test on_message_send when task_id is provided but task doesn't exist."""
task_id = 'nonexistent_task'
mock_task_store = AsyncMock(spec=TaskStore)

Expand All @@ -1798,10 +1798,10 @@ async def test_on_message_send_task_id_provided_but_task_not_found():
params = MessageSendParams(
message=Message(
role=Role.user,
messageId='msg_nonexistent',
message_id='msg_nonexistent',
parts=[Part(root=TextPart(text='Hello'))],
taskId=task_id,
contextId='ctx1',
task_id=task_id,
context_id='ctx1',
)
)

Expand All @@ -1827,7 +1827,7 @@ async def test_on_message_send_task_id_provided_but_task_not_found():

@pytest.mark.asyncio
async def test_on_message_send_stream_task_id_provided_but_task_not_found():
"""Test on_message_send_stream when taskId is provided but task doesn't exist."""
"""Test on_message_send_stream when task_id is provided but task doesn't exist."""
task_id = 'nonexistent_stream_task'
mock_task_store = AsyncMock(spec=TaskStore)

Expand All @@ -1838,10 +1838,10 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found():
params = MessageSendParams(
message=Message(
role=Role.user,
messageId='msg_nonexistent_stream',
message_id='msg_nonexistent_stream',
parts=[Part(root=TextPart(text='Hello'))],
taskId=task_id,
contextId='ctx1',
task_id=task_id,
context_id='ctx1',
)
)

Expand Down
22 changes: 18 additions & 4 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,8 +1551,22 @@ def test_camelCase() -> None:
supportsAuthenticatedExtendedCard=True,
)

agent_card.supportsAuthenticatedExtendedCard = False

default_input_modes = agent_card.defaultInputModes
assert agent_card
# Test setting an attribute via camelCase alias
# We expect a DeprecationWarning with a specific message
with pytest.warns(
DeprecationWarning,
match="Setting field 'supportsAuthenticatedExtendedCard'",
):
agent_card.supportsAuthenticatedExtendedCard = False

# Test getting an attribute via camelCase alias
# We expect another DeprecationWarning with a specific message
with pytest.warns(
DeprecationWarning, match="Accessing field 'defaultInputModes'"
):
default_input_modes = agent_card.defaultInputModes

# Assert the functionality still works as expected
assert agent_card.supports_authenticated_extended_card is False
assert default_input_modes == ['text']
assert agent_card.default_input_modes == ['text']
Loading