diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/assets.py index dd08cc16ec01d..00f58257b830b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/assets.py @@ -19,7 +19,7 @@ from datetime import datetime -from pydantic import AliasPath, ConfigDict, Field, NonNegativeInt, field_validator +from pydantic import AliasPath, ConfigDict, Field, JsonValue, NonNegativeInt, field_validator from airflow._shared.secrets_masker import redact from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel @@ -65,7 +65,7 @@ class AssetResponse(BaseModel): name: str uri: str group: str - extra: dict | None = None + extra: dict[str, JsonValue] | None = None created_at: datetime updated_at: datetime scheduled_dags: list[DagScheduleAssetReference] @@ -123,7 +123,7 @@ class AssetEventResponse(BaseModel): uri: str | None = Field(alias="uri", default=None) name: str | None = Field(alias="name", default=None) group: str | None = Field(alias="group", default=None) - extra: dict | None = None + extra: dict[str, JsonValue] | None = None source_task_id: str | None = None source_dag_id: str | None = None source_run_id: str | None = None diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 8a631c1909678..3c1729d1c973b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -8666,7 +8666,8 @@ components: title: Group extra: anyOf: - - additionalProperties: true + - additionalProperties: + $ref: '#/components/schemas/JsonValue' type: object - type: 'null' title: Extra @@ -8722,7 +8723,8 @@ components: title: Group extra: anyOf: - - additionalProperties: true + - additionalProperties: + $ref: '#/components/schemas/JsonValue' type: object - type: 'null' title: Extra diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset.py index 4cd23ddbb822b..c11bccdd59a60 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset.py @@ -17,6 +17,8 @@ from __future__ import annotations +from pydantic.types import JsonValue + from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel @@ -26,7 +28,7 @@ class AssetResponse(BaseModel): name: str uri: str group: str - extra: dict | None = None + extra: dict[str, JsonValue] | None = None class AssetAliasResponse(BaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_event.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_event.py index a9a66f242fc80..b050133212ebb 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_event.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_event.py @@ -19,6 +19,8 @@ from datetime import datetime +from pydantic.types import JsonValue + from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse @@ -41,7 +43,7 @@ class AssetEventResponse(BaseModel): id: int timestamp: datetime - extra: dict | None = None + extra: dict[str, JsonValue] | None = None asset: AssetResponse created_dagruns: list[DagRunAssetReference] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 9f639b1e3911a..f2e7aa36b0d9b 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -25,6 +25,7 @@ AwareDatetime, Discriminator, Field, + JsonValue, Tag, TypeAdapter, WithJsonSchema, @@ -258,7 +259,7 @@ class AssetReferenceAssetEventDagRun(StrictBaseModel): name: str uri: str - extra: dict + extra: dict[str, JsonValue] class AssetAliasReferenceAssetEventDagRun(StrictBaseModel): @@ -271,7 +272,7 @@ class AssetEventDagRunReference(StrictBaseModel): """Schema for AssetEvent model used in DagRun.""" asset: AssetReferenceAssetEventDagRun - extra: dict + extra: dict[str, JsonValue] source_task_id: str | None source_dag_id: str | None source_run_id: str | None diff --git a/airflow-core/src/airflow/lineage/hook.py b/airflow-core/src/airflow/lineage/hook.py index b69f12484dc1b..ede28a7855f03 100644 --- a/airflow-core/src/airflow/lineage/hook.py +++ b/airflow-core/src/airflow/lineage/hook.py @@ -29,6 +29,8 @@ from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: + from pydantic.types import JsonValue + from airflow.sdk import BaseHook, ObjectStoragePath # Store context what sent lineage. @@ -107,7 +109,7 @@ def create_asset( name: str | None = None, group: str | None = None, asset_kwargs: dict | None = None, - asset_extra: dict | None = None, + asset_extra: dict[str, JsonValue] | None = None, ) -> Asset | None: """ Create an asset instance using the provided parameters. @@ -161,7 +163,7 @@ def add_input_asset( name: str | None = None, group: str | None = None, asset_kwargs: dict | None = None, - asset_extra: dict | None = None, + asset_extra: dict[str, JsonValue] | None = None, ): """Add the input asset and its corresponding hook execution context to the collector.""" if len(self._inputs) >= MAX_COLLECTED_ASSETS: @@ -186,7 +188,7 @@ def add_output_asset( name: str | None = None, group: str | None = None, asset_kwargs: dict | None = None, - asset_extra: dict | None = None, + asset_extra: dict[str, JsonValue] | None = None, ): """Add the output asset and its corresponding hook execution context to the collector.""" if len(self._outputs) >= MAX_COLLECTED_ASSETS: diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index e9806696e243b..7b29184110302 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -20,6 +20,7 @@ import contextlib import hashlib import itertools +import json import logging import math import uuid @@ -1380,7 +1381,7 @@ def register_asset_changes_in_db( session=session, ) - def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], set[str]]: + def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str], set[str]]: d = defaultdict(set) for event in outlet_events: try: @@ -1390,19 +1391,20 @@ def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], if alias_name not in outlet_alias_names: continue asset_key = AssetUniqueKey(**event["dest_asset_key"]) - extra_key = frozenset(event["extra"].items()) - d[asset_key, extra_key].add(alias_name) + extra_json = json.dumps(event["extra"], sort_keys=True) + d[asset_key, extra_json].add(alias_name) return d outlet_alias_names = {o.name for o in task_outlets if o.type == AssetAlias.__name__ and o.name} if outlet_alias_names and (event_extras_from_aliases := _asset_event_extras_from_aliases()): - for (asset_key, extra_key), event_aliase_names in event_extras_from_aliases.items(): + for (asset_key, extra_json), event_aliase_names in event_extras_from_aliases.items(): + extra = json.loads(extra_json) ti.log.debug("register event for asset %s with aliases %s", asset_key, event_aliase_names) event = asset_manager.register_asset_change( task_instance=ti, asset=asset_key, source_alias_names=event_aliase_names, - extra=dict(extra_key), + extra=extra, session=session, ) if event is None: @@ -1413,7 +1415,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset], task_instance=ti, asset=asset_key, source_alias_names=event_aliase_names, - extra=dict(extra_key), + extra=extra, session=session, ) diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index f9c93f976019f..4f503ef9492f7 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -209,7 +209,9 @@ export const $AssetEventResponse = { extra: { anyOf: [ { - additionalProperties: true, + additionalProperties: { + '$ref': '#/components/schemas/JsonValue' + }, type: 'object' }, { @@ -295,7 +297,9 @@ export const $AssetResponse = { extra: { anyOf: [ { - additionalProperties: true, + additionalProperties: { + '$ref': '#/components/schemas/JsonValue' + }, type: 'object' }, { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 1396090691c41..f2e2c84f812d3 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -64,7 +64,7 @@ export type AssetEventResponse = { name?: string | null; group?: string | null; extra?: { - [key: string]: unknown; + [key: string]: JsonValue; } | null; source_task_id?: string | null; source_dag_id?: string | null; @@ -83,7 +83,7 @@ export type AssetResponse = { uri: string; group: string; extra?: { - [key: string]: unknown; + [key: string]: JsonValue; } | null; created_at: string; updated_at: string; diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 74260a05a11ce..4c7a0d9fd5c86 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -22,7 +22,7 @@ import operator import os import pathlib -from typing import cast +from typing import TYPE_CHECKING, cast from unittest import mock from unittest.mock import patch @@ -88,6 +88,11 @@ from tests_common.test_utils.mock_operators import MockOperator from unit.models import DEFAULT_DATE +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from tests_common.pytest_plugin import DagMaker + pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag, pytest.mark.want_activate_assets] @@ -1602,14 +1607,19 @@ def test_set_duration_empty_dates(self): ti.set_duration() assert ti.duration is None - def test_outlet_asset_extra(self, dag_maker, session): + def test_outlet_asset_extra(self, dag_maker: DagMaker, session: Session): from airflow.sdk.definitions.asset import Asset with dag_maker(schedule=None, serialized=True, session=session): @task(outlets=Asset("test_outlet_asset_extra_1")) - def write1(*, outlet_events): - outlet_events[Asset("test_outlet_asset_extra_1")].extra = {"foo": "bar"} + def write1(*, outlet_events=None): + if TYPE_CHECKING: + assert isinstance(outlet_events, dict) + outlet_events[Asset("test_outlet_asset_extra_1")].extra = { + "foo": "bar", + "this": {"is": "nested", "value": 1}, + } write1() @@ -1634,7 +1644,7 @@ def _write2_post_execute(context, _): assert events["write1"].source_run_id == dr.run_id assert events["write1"].source_task_id == "write1" assert events["write1"].asset.uri == "test_outlet_asset_extra_1" - assert events["write1"].extra == {"foo": "bar"} + assert events["write1"].extra == {"foo": "bar", "this": {"is": "nested", "value": 1}} assert events["write2"].source_dag_id == dr.dag_id assert events["write2"].source_run_id == dr.run_id diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index 6dc503ff099ab..14ca6a5a2cc9e 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -22,6 +22,7 @@ import sys from collections.abc import Iterator from datetime import datetime, timedelta +from typing import TYPE_CHECKING import pendulum import pytest @@ -85,6 +86,9 @@ from unit.models import DEFAULT_DATE +if TYPE_CHECKING: + from pydantic.types import JsonValue + DAG_ID = "dag_id_1" TEST_CALLBACK_PATH = f"{__name__}.empty_callback_for_deadline" @@ -227,7 +231,7 @@ def validate(self, obj): def create_outlet_event_accessors( - key: Asset | AssetAlias, extra: dict, asset_alias_events: list[AssetAliasEvent] + key: Asset | AssetAlias, extra: dict[str, JsonValue], asset_alias_events: list[AssetAliasEvent] ) -> OutletEventAccessors: o = OutletEventAccessors() o[key].extra = extra diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index 9e8f9f7c759f4..48b96aeaa0cd9 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -1028,7 +1028,7 @@ class AssetEventResponse(BaseModel): uri: Annotated[str | None, Field(title="Uri")] = None name: Annotated[str | None, Field(title="Name")] = None group: Annotated[str | None, Field(title="Group")] = None - extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None + extra: Annotated[dict[str, JsonValue] | None, Field(title="Extra")] = None source_task_id: Annotated[str | None, Field(title="Source Task Id")] = None source_dag_id: Annotated[str | None, Field(title="Source Dag Id")] = None source_run_id: Annotated[str | None, Field(title="Source Run Id")] = None @@ -1046,7 +1046,7 @@ class AssetResponse(BaseModel): name: Annotated[str, Field(title="Name")] uri: Annotated[str, Field(title="Uri")] group: Annotated[str, Field(title="Group")] - extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None + extra: Annotated[dict[str, JsonValue] | None, Field(title="Extra")] = None created_at: Annotated[datetime, Field(title="Created At")] updated_at: Annotated[datetime, Field(title="Updated At")] scheduled_dags: Annotated[list[DagScheduleAssetReference], Field(title="Scheduled Dags")] diff --git a/airflow-ctl/tests/airflow_ctl/api/test_operations.py b/airflow-ctl/tests/airflow_ctl/api/test_operations.py index 01df0de657706..1e7391b99d2cc 100644 --- a/airflow-ctl/tests/airflow_ctl/api/test_operations.py +++ b/airflow-ctl/tests/airflow_ctl/api/test_operations.py @@ -200,7 +200,7 @@ class TestAssetsOperations: id=asset_id, name="asset", uri="asset_uri", - extra={"extra": "extra"}, + extra={"extra": "extra"}, # type: ignore[dict-item] created_at=datetime.datetime(2024, 12, 31, 23, 59, 59), updated_at=datetime.datetime(2025, 1, 1, 0, 0, 0), scheduled_dags=[], diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index d288fecc1f200..b216c4372e3e7 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -63,30 +63,6 @@ class AssetProfile(BaseModel): type: Annotated[str, Field(title="Type")] -class AssetReferenceAssetEventDagRun(BaseModel): - """ - Schema for AssetModel used in AssetEventDagRunReference. - """ - - model_config = ConfigDict( - extra="forbid", - ) - name: Annotated[str, Field(title="Name")] - uri: Annotated[str, Field(title="Uri")] - extra: Annotated[dict[str, Any], Field(title="Extra")] - - -class AssetResponse(BaseModel): - """ - Asset schema for responses with fields that are needed for Runtime. - """ - - name: Annotated[str, Field(title="Name")] - uri: Annotated[str, Field(title="Uri")] - group: Annotated[str, Field(title="Group")] - extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None - - class ConnectionResponse(BaseModel): """ Connection schema for responses with fields that are needed for Runtime. @@ -477,6 +453,74 @@ class TriggerRule(str, Enum): ALL_SKIPPED = "all_skipped" +class AssetReferenceAssetEventDagRun(BaseModel): + """ + Schema for AssetModel used in AssetEventDagRunReference. + """ + + model_config = ConfigDict( + extra="forbid", + ) + name: Annotated[str, Field(title="Name")] + uri: Annotated[str, Field(title="Uri")] + extra: Annotated[dict[str, JsonValue], Field(title="Extra")] + + +class AssetResponse(BaseModel): + """ + Asset schema for responses with fields that are needed for Runtime. + """ + + name: Annotated[str, Field(title="Name")] + uri: Annotated[str, Field(title="Uri")] + group: Annotated[str, Field(title="Group")] + extra: Annotated[dict[str, JsonValue] | None, Field(title="Extra")] = None + + +class HITLDetailRequest(BaseModel): + """ + Schema for the request part of a Human-in-the-loop detail for a specific task instance. + """ + + ti_id: Annotated[UUID, Field(title="Ti Id")] + options: Annotated[list[str], Field(min_length=1, title="Options")] + subject: Annotated[str, Field(title="Subject")] + body: Annotated[str | None, Field(title="Body")] = None + defaults: Annotated[list[str] | None, Field(title="Defaults")] = None + multiple: Annotated[bool | None, Field(title="Multiple")] = False + params: Annotated[dict[str, Any] | None, Field(title="Params")] = None + assigned_users: Annotated[list[HITLUser] | None, Field(title="Assigned Users")] = None + + +class HITLDetailResponse(BaseModel): + """ + Schema for the response part of a Human-in-the-loop detail for a specific task instance. + """ + + response_received: Annotated[bool, Field(title="Response Received")] + responded_by_user: HITLUser | None = None + responded_at: Annotated[AwareDatetime | None, Field(title="Responded At")] = None + chosen_options: Annotated[list[str] | None, Field(title="Chosen Options")] = None + params_input: Annotated[dict[str, Any] | None, Field(title="Params Input")] = None + + +class HTTPValidationError(BaseModel): + detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None + + +class TITerminalStatePayload(BaseModel): + """ + Schema for updating TaskInstance to a terminal state except SUCCESS state. + """ + + model_config = ConfigDict( + extra="forbid", + ) + state: TerminalStateNonSuccess + end_date: Annotated[AwareDatetime, Field(title="End Date")] + rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None + + class AssetEventDagRunReference(BaseModel): """ Schema for AssetEvent model used in DagRun. @@ -486,7 +530,7 @@ class AssetEventDagRunReference(BaseModel): extra="forbid", ) asset: AssetReferenceAssetEventDagRun - extra: Annotated[dict[str, Any], Field(title="Extra")] + extra: Annotated[dict[str, JsonValue], Field(title="Extra")] source_task_id: Annotated[str | None, Field(title="Source Task Id")] = None source_dag_id: Annotated[str | None, Field(title="Source Dag Id")] = None source_run_id: Annotated[str | None, Field(title="Source Run Id")] = None @@ -502,7 +546,7 @@ class AssetEventResponse(BaseModel): id: Annotated[int, Field(title="Id")] timestamp: Annotated[AwareDatetime, Field(title="Timestamp")] - extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None + extra: Annotated[dict[str, JsonValue] | None, Field(title="Extra")] = None asset: AssetResponse created_dagruns: Annotated[list[DagRunAssetReference], Field(title="Created Dagruns")] source_task_id: Annotated[str | None, Field(title="Source Task Id")] = None @@ -543,37 +587,6 @@ class DagRun(BaseModel): consumed_asset_events: Annotated[list[AssetEventDagRunReference], Field(title="Consumed Asset Events")] -class HITLDetailRequest(BaseModel): - """ - Schema for the request part of a Human-in-the-loop detail for a specific task instance. - """ - - ti_id: Annotated[UUID, Field(title="Ti Id")] - options: Annotated[list[str], Field(min_length=1, title="Options")] - subject: Annotated[str, Field(title="Subject")] - body: Annotated[str | None, Field(title="Body")] = None - defaults: Annotated[list[str] | None, Field(title="Defaults")] = None - multiple: Annotated[bool | None, Field(title="Multiple")] = False - params: Annotated[dict[str, Any] | None, Field(title="Params")] = None - assigned_users: Annotated[list[HITLUser] | None, Field(title="Assigned Users")] = None - - -class HITLDetailResponse(BaseModel): - """ - Schema for the response part of a Human-in-the-loop detail for a specific task instance. - """ - - response_received: Annotated[bool, Field(title="Response Received")] - responded_by_user: HITLUser | None = None - responded_at: Annotated[AwareDatetime | None, Field(title="Responded At")] = None - chosen_options: Annotated[list[str] | None, Field(title="Chosen Options")] = None - params_input: Annotated[dict[str, Any] | None, Field(title="Params Input")] = None - - -class HTTPValidationError(BaseModel): - detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None - - class TIRunContext(BaseModel): """ Response schema for TaskInstance run context. @@ -591,16 +604,3 @@ class TIRunContext(BaseModel): next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To Clear")] = None should_retry: Annotated[bool | None, Field(title="Should Retry")] = False - - -class TITerminalStatePayload(BaseModel): - """ - Schema for updating TaskInstance to a terminal state except SUCCESS state. - """ - - model_config = ConfigDict( - extra="forbid", - ) - state: TerminalStateNonSuccess - end_date: Annotated[AwareDatetime, Field(title="End Date")] - rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 4f986efa6bbc3..98a9e270c62a5 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -35,6 +35,8 @@ from collections.abc import Iterable, Iterator from urllib.parse import SplitResult + from pydantic.types import JsonValue + from airflow.models.asset import AssetModel from airflow.sdk.io.path import ObjectStoragePath from airflow.serialization.serialized_objects import SerializedAssetWatcher @@ -216,7 +218,7 @@ def _validate_asset_name(instance, attribute, value): return value -def _set_extra_default(extra: dict | None) -> dict: +def _set_extra_default(extra: dict[str, JsonValue] | None) -> dict: """ Automatically convert None to an empty dict. @@ -319,7 +321,7 @@ class Asset(os.PathLike, BaseAsset): default=attrs.Factory(operator.attrgetter("asset_type"), takes_self=True), validator=[_validate_identifier], ) - extra: dict[str, Any] = attrs.field( + extra: dict[str, JsonValue] = attrs.field( factory=dict, converter=_set_extra_default, ) @@ -337,7 +339,7 @@ def __init__( uri: str | ObjectStoragePath, *, group: str = ..., - extra: dict | None = None, + extra: dict[str, JsonValue] | None = None, watchers: list[AssetWatcher | SerializedAssetWatcher] = ..., ) -> None: """Canonical; both name and uri are provided.""" @@ -348,7 +350,7 @@ def __init__( name: str, *, group: str = ..., - extra: dict | None = None, + extra: dict[str, JsonValue] | None = None, watchers: list[AssetWatcher | SerializedAssetWatcher] = ..., ) -> None: """It's possible to only provide the name, either by keyword or as the only positional argument.""" @@ -359,7 +361,7 @@ def __init__( *, uri: str | ObjectStoragePath, group: str = ..., - extra: dict | None = None, + extra: dict[str, JsonValue] | None = None, watchers: list[AssetWatcher | SerializedAssetWatcher] = ..., ) -> None: """It's possible to only provide the URI as a keyword argument.""" @@ -370,7 +372,7 @@ def __init__( uri: str | ObjectStoragePath | None = None, *, group: str | None = None, - extra: dict | None = None, + extra: dict[str, JsonValue] | None = None, watchers: list[AssetWatcher | SerializedAssetWatcher] | None = None, ) -> None: if name is None and uri is None: @@ -686,4 +688,4 @@ class AssetAliasEvent(attrs.AttrsInstance): source_alias_name: str dest_asset_key: AssetUniqueKey - extra: dict[str, Any] + extra: dict[str, JsonValue] diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index 44d23efdbbad8..eab493f5eaa1a 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -29,6 +29,8 @@ if TYPE_CHECKING: from collections.abc import Callable, Collection, Iterator, Mapping + from pydantic.types import JsonValue + from airflow.sdk import DAG, AssetAlias, ObjectStoragePath from airflow.sdk.bases.decorator import _TaskDecorator from airflow.sdk.definitions.asset import AssetUniqueKey @@ -218,7 +220,7 @@ class asset(_DAGFactory): name: str | None = None uri: str | ObjectStoragePath | None = None group: str = Asset.asset_type - extra: dict[str, Any] = attrs.field(factory=dict) + extra: dict[str, JsonValue] = attrs.field(factory=dict) watchers: list[BaseTrigger] = attrs.field(factory=list) @attrs.define(kw_only=True) diff --git a/task-sdk/src/airflow/sdk/definitions/asset/metadata.py b/task-sdk/src/airflow/sdk/definitions/asset/metadata.py index 88e886454290c..ee8c42b3ad4ae 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/metadata.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/metadata.py @@ -17,11 +17,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import attrs if TYPE_CHECKING: + from pydantic.types import JsonValue + from airflow.sdk.definitions.asset import Asset, AssetAlias __all__ = ["Metadata"] @@ -32,5 +34,5 @@ class Metadata: """Metadata to attach to an AssetEvent.""" asset: Asset - extra: dict[str, Any] = attrs.field(factory=dict) + extra: dict[str, JsonValue] = attrs.field(factory=dict) alias: AssetAlias | None = None diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index e4a647bab2f1a..02b9e90d04a11 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -44,6 +44,8 @@ if TYPE_CHECKING: from uuid import UUID + from pydantic.types import JsonValue + from airflow.sdk import Variable from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.connection import Connection @@ -477,10 +479,10 @@ class OutletEventAccessor(_AssetRefResolutionMixin): """Wrapper to access an outlet asset event in template.""" key: BaseAssetUniqueKey - extra: dict[str, Any] = attrs.Factory(dict) + extra: dict[str, JsonValue] = attrs.Factory(dict) asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) - def add(self, asset: Asset | AssetRef, extra: dict[str, Any] | None = None) -> None: + def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None) -> None: """Add an AssetEvent to an existing Asset.""" if not isinstance(self.key, AssetAliasUniqueKey): return diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 7138598027126..8352aa7ef5dff 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -789,7 +789,7 @@ def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]: yield AssetProfile(name=obj.name, type=AssetAlias.__name__) -def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, Any]]: +def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]: if TYPE_CHECKING: assert isinstance(events, OutletEventAccessors) # We just collect everything the user recorded in the accessors. diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 124fb560cc592..c87a7ebd6d148 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from collections.abc import Iterator - from pydantic import AwareDatetime + from pydantic import AwareDatetime, JsonValue from airflow.sdk._shared.logging.types import Logger as Logger from airflow.sdk.api.datamodels._generated import TaskInstanceState @@ -129,17 +129,17 @@ class OutletEventAccessorProtocol(Protocol): """Protocol for managing access to a specific outlet event accessor.""" key: BaseAssetUniqueKey - extra: dict[str, Any] + extra: dict[str, JsonValue] asset_alias_events: list[AssetAliasEvent] def __init__( self, *, key: BaseAssetUniqueKey, - extra: dict[str, Any], + extra: dict[str, JsonValue], asset_alias_events: list[AssetAliasEvent], ) -> None: ... - def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ... + def add(self, asset: Asset, extra: dict[str, JsonValue] | None = None) -> None: ... class OutletEventAccessorsProtocol(Protocol):