Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
5441801
Add snake_case generation
holtskinner Jul 16, 2025
c8eb963
Edit some snake_case
holtskinner Jul 16, 2025
0028bda
Update types.py
holtskinner Jul 17, 2025
7649d0c
Update usage in library
holtskinner Jul 17, 2025
d3a35e3
Update Pydantic
holtskinner Jul 17, 2025
572fd31
Undo pydantic upgrade
holtskinner Jul 17, 2025
71152f2
Switch to alias generator
holtskinner Jul 17, 2025
4504f0f
Switch to to_camel
holtskinner Jul 17, 2025
fd7dd8c
Fix missing changes
holtskinner Jul 17, 2025
f24739b
Merge branch 'main' into snake-case-field
holtskinner Jul 18, 2025
38c1dc9
Add --field-constraints to generate_types
holtskinner Jul 18, 2025
8597e89
Fix duplicate fields
holtskinner Jul 18, 2025
3a1e924
Fix push_notification_config
holtskinner Jul 18, 2025
167f003
Fix links
holtskinner Jul 21, 2025
5b1f8d0
re-add Alias
holtskinner Jul 21, 2025
e3d5057
Add pydantic mypy plugin
holtskinner Jul 21, 2025
fe8beff
Use custom camelCase alias generator
holtskinner Jul 21, 2025
4944c4f
Merge branch 'main' into snake-case-field
holtskinner Jul 21, 2025
02db1a5
Fixed camelCase in JSON Payload
holtskinner Jul 21, 2025
9baefef
Merge branch 'snake-case-field' of https://github.com/google-a2a/a2a-…
holtskinner Jul 21, 2025
14b8308
Update pyproject.toml
holtskinner Jul 21, 2025
66011b6
Add mypy plugin
holtskinner Jul 21, 2025
59d5ffb
Add backwards compatibility for camelCase
holtskinner Jul 21, 2025
d5253af
Add Lint ignore
holtskinner Jul 21, 2025
04f105a
spelling
holtskinner Jul 21, 2025
28656a0
Add support for camelCase `__getattr__`
holtskinner Jul 21, 2025
71adfec
Simplify `__setattr__` and `__getattr__` implementation
holtskinner Jul 21, 2025
e36211b
Linting
holtskinner Jul 21, 2025
aebec89
Merge branch 'main' into snake-case-field
holtskinner Jul 21, 2025
373959d
Format spelling
holtskinner Jul 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/actions/spelling/allow.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ pyversions
respx
resub
RUF
SLF
socio
sse
tagwords
Expand Down
1 change: 1 addition & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[mypy]
exclude = src/a2a/grpc/
disable_error_code = import-not-found,annotation-unchecked,import-untyped
plugins = pydantic.mypy

[mypy-examples.*]
follow_imports = skip
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ url = "https://test.pypi.org/simple/"
publish-url = "https://test.pypi.org/legacy/"
explicit = true

[tool.mypy]
plugins = ['pydantic.mypy']

[tool.pyright]
include = ["src"]
exclude = [
Expand Down
5 changes: 4 additions & 1 deletion scripts/generate_types.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ uv run datamodel-codegen \
--class-name A2A \
--use-standard-collections \
--use-subclass-enum \
--base-class a2a._base.A2ABaseModel
--base-class a2a._base.A2ABaseModel \
--field-constraints \
--snake-case-field \
--no-alias

echo "Formatting generated file with ruff..."
uv run ruff format "$GENERATED_FILE"
Expand Down
62 changes: 62 additions & 0 deletions src/a2a/_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,77 @@
from typing import Any, ClassVar

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


def to_camel_custom(snake: str) -> str:
"""Convert a snake_case string to camelCase.

Args:
snake: The string to convert.

Returns:
The converted camelCase string.
"""
# First, remove any trailing underscores. This is common for names that
# conflict with Python keywords, like 'in_' or 'from_'.
if snake.endswith('_'):
snake = snake.rstrip('_')
return to_camel(snake)


class A2ABaseModel(BaseModel):
"""Base class for shared behavior across A2A data models.

Provides a common configuration (e.g., alias-based population) and
serves as the foundation for future extensions or shared utilities.

This implementation provides backward compatibility for camelCase aliases
by lazy-loading an alias map upon first use.
"""

model_config = ConfigDict(
# SEE: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.populate_by_name
validate_by_name=True,
validate_by_alias=True,
serialize_by_alias=True,
alias_generator=to_camel_custom,
)

# Cache for the alias -> field_name mapping.
# It starts as None and is populated on first access.
_alias_to_field_name_map: ClassVar[dict[str, str] | None] = None

@classmethod
def _get_alias_map(cls) -> dict[str, str]:
"""Lazily builds and returns the alias-to-field-name mapping for the class.

The map is cached on the class object to avoid re-computation.
"""
if cls._alias_to_field_name_map is None:
cls._alias_to_field_name_map = {
field.alias: field_name
for field_name, field in cls.model_fields.items()
if field.alias is not None
}
return cls._alias_to_field_name_map

def __setattr__(self, name: str, value: Any) -> None:
"""Allow setting attributes via their camelCase alias."""
# Get the map and find the corresponding snake_case field name.
field_name = type(self)._get_alias_map().get(name) # noqa: SLF001
# If an alias was used, field_name will be set; otherwise, use the original name.
super().__setattr__(field_name or name, value)

def __getattr__(self, name: str) -> Any:
"""Allow getting attributes via their camelCase alias."""
# Get the map and find the corresponding snake_case field name.
field_name = type(self)._get_alias_map().get(name) # noqa: SLF001
if field_name:
# If an alias was used, retrieve the actual snake_case attribute.
return getattr(self, field_name)

# If it's not a known alias, it's a genuine missing attribute.
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
6 changes: 3 additions & 3 deletions src/a2a/client/auth/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def intercept(
if (
agent_card is None
or agent_card.security is None
or agent_card.securitySchemes is None
or agent_card.security_schemes is None
):
return request_payload, http_kwargs

Expand All @@ -45,8 +45,8 @@ async def intercept(
credential = await self._credential_service.get_credentials(
scheme_name, context
)
if credential and scheme_name in agent_card.securitySchemes:
scheme_def_union = agent_card.securitySchemes.get(
if credential and scheme_name in agent_card.security_schemes:
scheme_def_union = agent_card.security_schemes.get(
scheme_name
)
if not scheme_def_union:
Expand Down
4 changes: 2 additions & 2 deletions src/a2a/client/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def create_text_message_object(
content: The text content of the message. Defaults to an empty string.

Returns:
A `Message` object with a new UUID messageId.
A `Message` object with a new UUID message_id.
"""
return Message(
role=role, parts=[Part(TextPart(text=content))], messageId=str(uuid4())
role=role, parts=[Part(TextPart(text=content))], message_id=str(uuid4())
)
22 changes: 11 additions & 11 deletions src/a2a/server/agent_execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ def __init__( # noqa: PLR0913
# match the request. Otherwise, create them
if self._params:
if task_id:
self._params.message.taskId = task_id
self._params.message.task_id = task_id
if task and task.id != task_id:
raise ServerError(InvalidParamsError(message='bad task id'))
else:
self._check_or_generate_task_id()
if context_id:
self._params.message.contextId = context_id
if task and task.contextId != context_id:
self._params.message.context_id = context_id
if task and task.context_id != context_id:
raise ServerError(
InvalidParamsError(message='bad context id')
)
Expand Down Expand Up @@ -148,17 +148,17 @@ def _check_or_generate_task_id(self) -> None:
if not self._params:
return

if not self._task_id and not self._params.message.taskId:
self._params.message.taskId = str(uuid.uuid4())
if self._params.message.taskId:
self._task_id = self._params.message.taskId
if not self._task_id and not self._params.message.task_id:
self._params.message.task_id = str(uuid.uuid4())
if self._params.message.task_id:
self._task_id = self._params.message.task_id

def _check_or_generate_context_id(self) -> None:
"""Ensures a context ID is present, generating one if necessary."""
if not self._params:
return

if not self._context_id and not self._params.message.contextId:
self._params.message.contextId = str(uuid.uuid4())
if self._params.message.contextId:
self._context_id = self._params.message.contextId
if not self._context_id and not self._params.message.context_id:
self._params.message.context_id = str(uuid.uuid4())
if self._params.message.context_id:
self._context_id = self._params.message.context_id
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

Args:
should_populate_referred_tasks: If True, the builder will fetch tasks
referenced in `params.message.referenceTaskIds` and populate the
referenced in `params.message.reference_task_ids` and populate the
`related_tasks` field in the RequestContext. Defaults to False.
task_store: The TaskStore instance to use for fetching referred tasks.
Required if `should_populate_referred_tasks` is True.
Expand All @@ -26,43 +26,43 @@
self._task_store = task_store
self._should_populate_referred_tasks = should_populate_referred_tasks

async def build(
self,
params: MessageSendParams | None = None,
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
context: ServerCallContext | None = None,
) -> RequestContext:
"""Builds the request context for an agent execution.

This method assembles the RequestContext object. If the builder was
initialized with `should_populate_referred_tasks=True`, it fetches all tasks
referenced in `params.message.referenceTaskIds` from the `task_store`.
referenced in `params.message.reference_task_ids` from the `task_store`.

Args:
params: The parameters of the incoming message send request.
task_id: The ID of the task being executed.
context_id: The ID of the current execution context.
task: The primary task object associated with the request.
context: The server call context, containing metadata about the call.

Returns:
An instance of RequestContext populated with the provided information
and potentially a list of related tasks.
"""

Check notice on line 53 in src/a2a/server/agent_execution/simple_request_context_builder.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/agent_execution/request_context_builder.py (12-20)
related_tasks: list[Task] | None = None

if (
self._task_store
and self._should_populate_referred_tasks
and params
and params.message.referenceTaskIds
and params.message.reference_task_ids
):
tasks = await asyncio.gather(
*[
self._task_store.get(task_id)
for task_id in params.message.referenceTaskIds
for task_id in params.message.reference_task_ids
]
)
related_tasks = [x for x in tasks if x is not None]
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/server/apps/jsonrpc/fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def add_routes_to_app(
)(self._handle_requests)
app.get(agent_card_url)(self._handle_get_agent_card)

if self.agent_card.supportsAuthenticatedExtendedCard:
if self.agent_card.supports_authenticated_extended_card:
app.get(extended_agent_card_url)(
self._handle_get_authenticated_extended_agent_card
)
Expand Down
8 changes: 4 additions & 4 deletions src/a2a/server/apps/jsonrpc/jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ def __init__(
agent_card=agent_card, request_handler=http_handler
)
if (
self.agent_card.supportsAuthenticatedExtendedCard
self.agent_card.supports_authenticated_extended_card
and self.extended_agent_card is None
):
logger.error(
'AgentCard.supportsAuthenticatedExtendedCard is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
'AgentCard.supports_authenticated_extended_card is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
)
self._context_builder = context_builder or DefaultCallContextBuilder()

Expand Down Expand Up @@ -421,7 +421,7 @@ async def _handle_get_authenticated_extended_agent_card(
self, request: Request
) -> JSONResponse:
"""Handles GET requests for the authenticated extended agent card."""
if not self.agent_card.supportsAuthenticatedExtendedCard:
if not self.agent_card.supports_authenticated_extended_card:
return JSONResponse(
{'error': 'Extended agent card not supported or not enabled.'},
status_code=404,
Expand All @@ -435,7 +435,7 @@ async def _handle_get_authenticated_extended_agent_card(
by_alias=True,
)
)
# If supportsAuthenticatedExtendedCard is true, but no specific
# If supports_authenticated_extended_card is true, but no specific
# extended_agent_card was provided during server initialization,
# return a 404
return JSONResponse(
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/server/apps/jsonrpc/starlette_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def routes(
),
]

if self.agent_card.supportsAuthenticatedExtendedCard:
if self.agent_card.supports_authenticated_extended_card:
app_routes.append(
Route(
extended_agent_card_url,
Expand Down
10 changes: 5 additions & 5 deletions src/a2a/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class TaskMixin:
"""Mixin providing standard task columns with proper type handling."""

id: Mapped[str] = mapped_column(String(36), primary_key=True, index=True)
contextId: Mapped[str] = mapped_column(String(36), nullable=False) # noqa: N815
context_id: Mapped[str] = mapped_column(String(36), nullable=False)
kind: Mapped[str] = mapped_column(
String(16), nullable=False, default='task'
)
Expand All @@ -148,12 +148,12 @@ def task_metadata(cls) -> Mapped[dict[str, Any] | None]:
def __repr__(self) -> str:
"""Return a string representation of the task."""
repr_template = (
'<{CLS}(id="{ID}", contextId="{CTX_ID}", status="{STATUS}")>'
'<{CLS}(id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
)
return repr_template.format(
CLS=self.__class__.__name__,
ID=self.id,
CTX_ID=self.contextId,
CTX_ID=self.context_id,
STATUS=self.status,
)

Expand Down Expand Up @@ -188,11 +188,11 @@ class TaskModel(TaskMixin, base):
@override
def __repr__(self) -> str:
"""Return a string representation of the task."""
repr_template = '<TaskModel[{TABLE}](id="{ID}", contextId="{CTX_ID}", status="{STATUS}")>'
repr_template = '<TaskModel[{TABLE}](id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
return repr_template.format(
TABLE=table_name,
ID=self.id,
CTX_ID=self.contextId,
CTX_ID=self.context_id,
STATUS=self.status,
)

Expand Down
Loading
Loading