diff --git a/airflow/api_connexion/endpoints/asset_endpoint.py b/airflow/api_connexion/endpoints/asset_endpoint.py index 1bda4fdb2a218..64930b1249468 100644 --- a/airflow/api_connexion/endpoints/asset_endpoint.py +++ b/airflow/api_connexion/endpoints/asset_endpoint.py @@ -45,7 +45,6 @@ ) from airflow.assets.manager import asset_manager from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel -from airflow.sdk.definitions.asset import Asset from airflow.utils import timezone from airflow.utils.api_migration import mark_fastapi_migration_done from airflow.utils.db import get_query_count @@ -341,15 +340,16 @@ def create_asset_event(session: Session = NEW_SESSION) -> APIResponse: except ValidationError as err: raise BadRequest(detail=str(err)) + # TODO: handle name uri = json_body["asset_uri"] - asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri).limit(1)) - if not asset: + asset_model = session.scalar(select(AssetModel).where(AssetModel.uri == uri).limit(1)) + if not asset_model: raise NotFound(title="Asset not found", detail=f"Asset with uri: '{uri}' not found") timestamp = timezone.utcnow() extra = json_body.get("extra", {}) extra["from_rest_api"] = True asset_event = asset_manager.register_asset_change( - asset=Asset(uri=uri), + asset=asset_model.to_public(), timestamp=timestamp, extra=extra, session=session, diff --git a/airflow/api_connexion/schemas/asset_schema.py b/airflow/api_connexion/schemas/asset_schema.py index e83c4f1b42797..078ebb3e75866 100644 --- a/airflow/api_connexion/schemas/asset_schema.py +++ b/airflow/api_connexion/schemas/asset_schema.py @@ -70,6 +70,7 @@ class Meta: id = auto_field() name = auto_field() + group = auto_field() class AssetSchema(SQLAlchemySchema): @@ -82,6 +83,8 @@ class Meta: id = auto_field() uri = auto_field() + name = auto_field() + group = auto_field() extra = JsonObjectField() created_at = auto_field() updated_at = auto_field() diff --git a/airflow/api_fastapi/core_api/datamodels/assets.py b/airflow/api_fastapi/core_api/datamodels/assets.py index 638ee1cba6e29..72bba200fabf3 100644 --- a/airflow/api_fastapi/core_api/datamodels/assets.py +++ b/airflow/api_fastapi/core_api/datamodels/assets.py @@ -47,6 +47,7 @@ class AssetAliasSchema(BaseModel): id: int name: str + group: str class AssetResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 6c34fc4cf41f3..7e33084929043 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -5742,10 +5742,14 @@ components: name: type: string title: Name + group: + type: string + title: Group type: object required: - id - name + - group title: AssetAliasSchema description: Asset alias serializer for assets. 
AssetCollectionResponse: diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow/api_fastapi/core_api/routes/public/assets.py index 70bf5b047bbac..db3fa61767a9a 100644 --- a/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow/api_fastapi/core_api/routes/public/assets.py @@ -51,7 +51,6 @@ from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.assets.manager import asset_manager from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel -from airflow.sdk.definitions.asset import Asset from airflow.utils import timezone assets_router = AirflowRouter(tags=["Asset"]) @@ -171,13 +170,13 @@ def create_asset_event( session: SessionDep, ) -> AssetEventResponse: """Create asset events.""" - asset = session.scalar(select(AssetModel).where(AssetModel.uri == body.uri).limit(1)) - if not asset: + asset_model = session.scalar(select(AssetModel).where(AssetModel.uri == body.uri).limit(1)) + if not asset_model: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Asset with uri: `{body.uri}` was not found") timestamp = timezone.utcnow() assets_event = asset_manager.register_asset_change( - asset=Asset(uri=body.uri), + asset=asset_model.to_public(), timestamp=timestamp, extra=body.extra, session=session, diff --git a/airflow/assets/manager.py b/airflow/assets/manager.py index 40bc97b8134c7..364d01607e5c4 100644 --- a/airflow/assets/manager.py +++ b/airflow/assets/manager.py @@ -86,16 +86,16 @@ def _add_one(asset_alias: AssetAlias) -> AssetAliasModel: def _add_asset_alias_association( cls, alias_names: Collection[str], - asset: AssetModel, + asset_model: AssetModel, *, session: Session, ) -> None: - already_related = {m.name for m in asset.aliases} + already_related = {m.name for m in asset_model.aliases} existing_aliases = { m.name: m for m in session.scalars(select(AssetAliasModel).where(AssetAliasModel.name.in_(alias_names))) } - asset.aliases.extend( + asset_model.aliases.extend( existing_aliases.get(name, AssetAliasModel(name=name)) for name in alias_names if name not in already_related @@ -121,7 +121,7 @@ def register_asset_change( """ asset_model = session.scalar( select(AssetModel) - .where(AssetModel.uri == asset.uri) + .where(AssetModel.name == asset.name, AssetModel.uri == asset.uri) .options( joinedload(AssetModel.aliases), joinedload(AssetModel.consuming_dags).joinedload(DagScheduleAssetReference.dag), @@ -131,7 +131,9 @@ def register_asset_change( cls.logger().warning("AssetModel %s not found", asset) return None - cls._add_asset_alias_association({alias.name for alias in aliases}, asset_model, session=session) + cls._add_asset_alias_association( + alias_names={alias.name for alias in aliases}, asset_model=asset_model, session=session + ) event_kwargs = { "asset_id": asset_model.id, diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py index 9e5f8f6648229..62a2c7a54933f 100644 --- a/airflow/lineage/hook.py +++ b/airflow/lineage/hook.py @@ -95,24 +95,40 @@ def _generate_key(self, asset: Asset, context: LineageContext) -> str: return f"{asset.uri}_{extra_hash}_{id(context)}" def create_asset( - self, scheme: str | None, uri: str | None, asset_kwargs: dict | None, asset_extra: dict | None + self, + *, + scheme: str | None = None, + uri: str | None = None, + name: str | None = None, + group: str | None = None, + asset_kwargs: dict | None = None, + asset_extra: dict | None = None, ) -> Asset | None: """ Create an asset instance using the provided parameters. 
This method attempts to create an asset instance using the given parameters. - It first checks if a URI is provided and falls back to using the default asset factory - with the given URI if no other information is available. + It first checks if a URI or a name is provided and falls back to using the default asset factory + with the given URI or name if no other information is available. - If a scheme is provided but no URI, it attempts to find an asset factory that matches + If a scheme is provided but no URI or name, it attempts to find an asset factory that matches the given scheme. If no such factory is found, it logs an error message and returns None. If asset_kwargs is provided, it is used to pass additional parameters to the asset factory. The asset_extra parameter is also passed to the factory as an ``extra`` parameter. """ - if uri: + if uri or name: # Fallback to default factory using the provided URI - return Asset(uri=uri, extra=asset_extra) + kwargs: dict[str, str | dict] = {} + if uri: + kwargs["uri"] = uri + if name: + kwargs["name"] = name + if group: + kwargs["group"] = group + if asset_extra: + kwargs["extra"] = asset_extra + return Asset(**kwargs) # type: ignore[call-overload] if not scheme: self.log.debug( @@ -137,11 +153,15 @@ def add_input_asset( context: LineageContext, scheme: str | None = None, uri: str | None = None, + name: str | None = None, + group: str | None = None, asset_kwargs: dict | None = None, asset_extra: dict | None = None, ): """Add the input asset and its corresponding hook execution context to the collector.""" - asset = self.create_asset(scheme=scheme, uri=uri, asset_kwargs=asset_kwargs, asset_extra=asset_extra) + asset = self.create_asset( + scheme=scheme, uri=uri, name=name, group=group, asset_kwargs=asset_kwargs, asset_extra=asset_extra + ) if asset: key = self._generate_key(asset, context) if key not in self._inputs: @@ -153,11 +173,15 @@ def add_output_asset( context: LineageContext, scheme: str | None = None, uri: str | None = None, + name: str | None = None, + group: str | None = None, asset_kwargs: dict | None = None, asset_extra: dict | None = None, ): """Add the output asset and its corresponding hook execution context to the collector.""" - asset = self.create_asset(scheme=scheme, uri=uri, asset_kwargs=asset_kwargs, asset_extra=asset_extra) + asset = self.create_asset( + scheme=scheme, uri=uri, name=name, group=group, asset_kwargs=asset_kwargs, asset_extra=asset_extra + ) if asset: key = self._generate_key(asset, context) if key not in self._outputs: diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 1a13430e2fcb5..f78a2b78b8811 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -44,7 +44,11 @@ from airflow.models.connection import Connection from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun -from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key +from airflow.models.expandinput import ( + EXPAND_INPUT_EMPTY, + create_expand_input, + get_map_type_key, +) from airflow.models.mappedoperator import MappedOperator from airflow.models.param import Param, ParamsDict from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance @@ -213,7 +217,9 @@ def _get_registered_timetable(importable_string: str) -> type[Timetable] | None: return None -def _get_registered_priority_weight_strategy(importable_string: str) -> type[PriorityWeightStrategy] | 
None: +def _get_registered_priority_weight_strategy( + importable_string: str, +) -> type[PriorityWeightStrategy] | None: from airflow import plugins_manager if importable_string in airflow_priority_weight_strategies: @@ -256,13 +262,25 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]: :meta private: """ if isinstance(var, Asset): - return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri, "extra": var.extra} + return { + "__type": DAT.ASSET, + "name": var.name, + "uri": var.uri, + "group": var.group, + "extra": var.extra, + } if isinstance(var, AssetAlias): - return {"__type": DAT.ASSET_ALIAS, "name": var.name} + return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group": var.group} if isinstance(var, AssetAll): - return {"__type": DAT.ASSET_ALL, "objects": [encode_asset_condition(x) for x in var.objects]} + return { + "__type": DAT.ASSET_ALL, + "objects": [encode_asset_condition(x) for x in var.objects], + } if isinstance(var, AssetAny): - return {"__type": DAT.ASSET_ANY, "objects": [encode_asset_condition(x) for x in var.objects]} + return { + "__type": DAT.ASSET_ANY, + "objects": [encode_asset_condition(x) for x in var.objects], + } raise ValueError(f"serialization not implemented for {type(var).__name__!r}") @@ -274,13 +292,13 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: """ dat = var["__type"] if dat == DAT.ASSET: - return Asset(uri=var["uri"], name=var["name"], extra=var["extra"]) + return Asset(name=var["name"], uri=var["uri"], group=var["group"], extra=var["extra"]) if dat == DAT.ASSET_ALL: return AssetAll(*(decode_asset_condition(x) for x in var["objects"])) if dat == DAT.ASSET_ANY: return AssetAny(*(decode_asset_condition(x) for x in var["objects"])) if dat == DAT.ASSET_ALIAS: - return AssetAlias(name=var["name"]) + return AssetAlias(name=var["name"], group=var["group"]) raise ValueError(f"deserialization not implemented for DAT {dat!r}") @@ -586,7 +604,9 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool: @classmethod def serialize_to_json( - cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set + cls, + object_to_serialize: BaseOperator | MappedOperator | DAG, + decorated_fields: set, ) -> dict[str, Any]: """Serialize an object to JSON.""" serialized_object: dict[str, Any] = {} @@ -653,7 +673,11 @@ def serialize( return cls._encode(json_pod, type_=DAT.POD) elif isinstance(var, OutletEventAccessors): return cls._encode( - cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models), # type: ignore[attr-defined] + cls.serialize( + var._dict, # type: ignore[attr-defined] + strict=strict, + use_pydantic_models=use_pydantic_models, + ), type_=DAT.ASSET_EVENT_ACCESSORS, ) elif isinstance(var, OutletEventAccessor): @@ -696,7 +720,11 @@ def serialize( elif isinstance(var, (KeyError, AttributeError)): return cls._encode( cls.serialize( - {"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}}, + { + "exc_cls_name": var.__class__.__name__, + "args": [var.args], + "kwargs": {}, + }, use_pydantic_models=use_pydantic_models, strict=strict, ), @@ -704,7 +732,11 @@ def serialize( ) elif isinstance(var, BaseTrigger): return cls._encode( - cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict), + cls.serialize( + var.serialize(), + use_pydantic_models=use_pydantic_models, + strict=strict, + ), type_=DAT.BASE_TRIGGER, ) elif callable(var): @@ -1065,11 +1097,11 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]: 
source=task.dag_id, target="asset", dependency_type="asset", - dependency_id=obj.uri, + dependency_id=obj.name, ) ) elif isinstance(obj, AssetAlias): - cond = AssetAliasCondition(obj.name) + cond = AssetAliasCondition(name=obj.name, group=obj.group) deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target="")) return deps @@ -1298,7 +1330,11 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: # The case for "If OperatorLinks are defined in the operator that is being Serialized" # is handled in the deserialization loop where it matches k == "_operator_extra_links" if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op: - setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values())) + setattr( + op, + "operator_extra_links", + list(op_extra_links_from_plugin.values()), + ) for k, v in encoded_op.items(): # python_callable_name only serves to detect function name changes diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py index 60b1c141209fe..b80a6323a8c18 100644 --- a/airflow/timetables/base.py +++ b/airflow/timetables/base.py @@ -19,7 +19,7 @@ from collections.abc import Iterator, Sequence from typing import TYPE_CHECKING, Any, NamedTuple -from airflow.sdk.definitions.asset import BaseAsset +from airflow.sdk.definitions.asset import AssetUniqueKey, BaseAsset from airflow.typing_compat import Protocol, runtime_checkable if TYPE_CHECKING: @@ -55,7 +55,7 @@ def as_expression(self) -> Any: def evaluate(self, statuses: dict[str, bool]) -> bool: return False - def iter_assets(self) -> Iterator[tuple[str, Asset]]: + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: return iter(()) def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index f282c7fe67f8a..57eec884b558a 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -170,7 +170,7 @@ def __init__(self, assets: BaseAsset) -> None: super().__init__() self.asset_condition = assets if isinstance(self.asset_condition, AssetAlias): - self.asset_condition = AssetAliasCondition(self.asset_condition.name) + self.asset_condition = AssetAliasCondition.from_asset_alias(self.asset_condition) if not next(self.asset_condition.iter_assets(), False): self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index c2cf77baab6ff..6e05c32f75302 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -99,9 +99,13 @@ export const $AssetAliasSchema = { type: "string", title: "Name", }, + group: { + type: "string", + title: "Group", + }, }, type: "object", - required: ["id", "name"], + required: ["id", "name", "group"], title: "AssetAliasSchema", description: "Asset alias serializer for assets.", } as const; diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index dcb3ec94f9526..545ef594f471e 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -27,6 +27,7 @@ export type AppBuilderViewResponse = { export type AssetAliasSchema = { id: number; name: string; + group: string; }; /** diff --git a/newsfragments/43774.significant.rst b/newsfragments/43774.significant.rst new file mode 100644 index 0000000000000..b716e1fc83f94 --- /dev/null +++ 
b/newsfragments/43774.significant.rst @@ -0,0 +1,22 @@ +``HookLineageCollector.create_asset`` now accepts only keyword arguments +To provide AIP-74 support, new arguments "name" and "group" have been added to ``HookLineageCollector.create_asset``. +To make future changes easier, this function now takes only keyword arguments. +.. Check the type of change that applies to this change +* Types of change + * [ ] DAG changes + * [ ] Config changes + * [ ] API changes + * [ ] CLI changes + * [x] Behaviour changes + * [ ] Plugin changes + * [ ] Dependency change +.. List the migration rules needed for this change (see https://github.com/apache/airflow/issues/41641) +* Migration rules needed + * Calling ``HookLineageCollector.create_asset`` with positional arguments should raise an error diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index e84fac1186573..3d41e87cf0152 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -334,9 +334,9 @@ def test_serialize_timetable(): from airflow.timetables.simple import AssetTriggeredTimetable asset = AssetAny( - Asset("2"), - AssetAlias("example-alias"), - Asset("3"), + Asset(name="2", uri="test://2", group="test-group"), + AssetAlias(name="example-alias", group="test-group"), + Asset(name="3", uri="test://3", group="test-group"), AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")), ) dag = MagicMock() @@ -347,14 +347,32 @@ def test_serialize_timetable(): "asset_condition": { "__type": DagAttributeTypes.ASSET_ANY, "objects": [ - {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "2", "uri": "2"}, + { + "__type": DagAttributeTypes.ASSET, + "extra": {}, + "uri": "test://2/", + "name": "2", + "group": "test-group", + }, {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, - {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "3", "uri": "3"}, + { + "__type": DagAttributeTypes.ASSET, + "extra": {}, + "uri": "test://3/", + "name": "3", + "group": "test-group", + }, { "__type": DagAttributeTypes.ASSET_ALL, "objects": [ {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, - {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "4", "uri": "4"}, + { + "__type": DagAttributeTypes.ASSET, + "extra": {}, + "uri": "4", + "name": "4", + "group": "asset", + }, ], }, ], diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 812c30261bb97..81af48a6b41b4 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -27,6 +27,7 @@ Any, Callable, ClassVar, + NamedTuple, cast, overload, ) @@ -63,6 +64,15 @@ log = logging.getLogger(__name__) +class AssetUniqueKey(NamedTuple): + name: str + uri: str + + @staticmethod + def from_asset(asset: Asset) -> AssetUniqueKey: + return AssetUniqueKey(name=asset.name, uri=asset.uri) + + def normalize_noop(parts: SplitResult) -> SplitResult: """ Place-hold a :class:`~urllib.parse.SplitResult`` normalizer.
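A minimal usage sketch of the keyword-only ``HookLineageCollector.create_asset`` signature described in the newsfragment above; the collector instantiation and the example URI/name/group values are illustrative assumptions, not part of this patch.

.. code-block:: python

    # Sketch only: assumes HookLineageCollector() can be instantiated directly,
    # as the lineage tests in this patch do with self.collector.
    from airflow.lineage.hook import HookLineageCollector

    collector = HookLineageCollector()

    # The call must now use keyword arguments; the new AIP-74 "name" and
    # "group" fields can be passed alongside the URI (example values assumed).
    asset = collector.create_asset(
        uri="s3://bucket/key",
        name="my-asset",
        group="my-group",
        asset_extra={"foo": "bar"},
    )

    # A positional call such as collector.create_asset("s3", "s3://bucket/key")
    # now raises TypeError, since all parameters are keyword-only.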
@@ -203,7 +213,7 @@ def as_expression(self) -> Any: def evaluate(self, statuses: dict[str, bool]) -> bool: raise NotImplementedError - def iter_assets(self) -> Iterator[tuple[str, Asset]]: + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: raise NotImplementedError def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: @@ -350,10 +360,10 @@ def as_expression(self) -> Any: :meta private: """ - return self.uri + return {"asset": {"uri": self.uri, "name": self.name, "group": self.group}} - def iter_assets(self) -> Iterator[tuple[str, Asset]]: - yield self.uri, self + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: + yield AssetUniqueKey.from_asset(self), self def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: return iter(()) @@ -371,7 +381,7 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe source=source or "asset", target=target or "asset", dependency_type="asset", - dependency_id=self.uri, + dependency_id=self.name, ) @@ -401,7 +411,7 @@ class AssetAlias(BaseAsset): name: str = attrs.field(validator=_validate_non_empty_identifier) group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier) - def iter_assets(self) -> Iterator[tuple[str, Asset]]: + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: return iter(()) def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: @@ -439,13 +449,14 @@ def __init__(self, *objects: BaseAsset) -> None: raise TypeError("expect asset expressions in condition") self.objects = [ - AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects + AssetAliasCondition.from_asset_alias(obj) if isinstance(obj, AssetAlias) else obj + for obj in objects ] def evaluate(self, statuses: dict[str, bool]) -> bool: return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects) - def iter_assets(self) -> Iterator[tuple[str, Asset]]: + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: seen = set() # We want to keep the first instance. 
for o in self.objects: for k, v in o.iter_assets(): @@ -515,8 +526,9 @@ class AssetAliasCondition(AssetAny): :meta private: """ - def __init__(self, name: str) -> None: + def __init__(self, name: str, group: str) -> None: self.name = name + self.group = group self.objects = expand_alias_to_assets(name) def __repr__(self) -> str: @@ -528,7 +540,7 @@ def as_expression(self) -> Any: :meta private: """ - return {"alias": self.name} + return {"alias": {"name": self.name, "group": self.group}} def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: yield self.name, AssetAlias(self.name) @@ -542,18 +554,18 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat if self.objects: for obj in self.objects: asset = cast(Asset, obj) - uri = asset.uri + asset_name = asset.name # asset yield DagDependency( source=f"asset-alias:{self.name}" if source else "asset", target="asset" if source else f"asset-alias:{self.name}", dependency_type="asset", - dependency_id=uri, + dependency_id=asset_name, ) # asset alias yield DagDependency( - source=source or f"asset:{uri}", - target=target or f"asset:{uri}", + source=source or f"asset:{asset_name}", + target=target or f"asset:{asset_name}", dependency_type="asset-alias", dependency_id=self.name, ) @@ -565,6 +577,10 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat dependency_id=self.name, ) + @staticmethod + def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasCondition: + return AssetAliasCondition(name=asset_alias.name, group=asset_alias.group) + class AssetAll(_AssetBooleanCondition): """Use to combine assets schedule references in an "or" relationship.""" diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index ef602ea5a2267..d9aa6305f579d 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -170,18 +170,18 @@ def test_asset_logic_operations(): def test_asset_iter_assets(): - assert list(asset1.iter_assets()) == [("s3://bucket1/data1", asset1)] + assert list(asset1.iter_assets()) == [(("asset-1", "s3://bucket1/data1"), asset1)] @pytest.mark.db_test def test_asset_iter_asset_aliases(): base_asset = AssetAll( - AssetAlias("example-alias-1"), + AssetAlias(name="example-alias-1"), Asset("1"), AssetAny( - Asset("2"), + Asset(name="2", uri="test://asset1"), AssetAlias("example-alias-2"), - Asset("3"), + Asset(name="3"), AssetAll(AssetAlias("example-alias-3"), Asset("4"), AssetAlias("example-alias-4")), ), AssetAll(AssetAlias("example-alias-5"), Asset("5")), @@ -225,8 +225,14 @@ def test_assset_boolean_condition_evaluate_iter(): # Testing iter_assets indirectly through the subclasses assets_any = dict(any_condition.iter_assets()) assets_all = dict(all_condition.iter_assets()) - assert assets_any == {"s3://bucket1/data1": asset1, "s3://bucket2/data2": asset2} - assert assets_all == {"s3://bucket1/data1": asset1, "s3://bucket2/data2": asset2} + assert assets_any == { + ("asset-1", "s3://bucket1/data1"): asset1, + ("asset-2", "s3://bucket2/data2"): asset2, + } + assert assets_all == { + ("asset-1", "s3://bucket1/data1"): asset1, + ("asset-2", "s3://bucket2/data2"): asset2, + } @pytest.mark.parametrize( @@ -254,7 +260,7 @@ def test_assset_boolean_condition_evaluate_iter(): ) def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected): class_ = AssetAny if scenario == "any" else AssetAll - assets = [Asset(uri=f"s3://abc/{i}") for i in range(123, 126)] + assets = 
[Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in range(123, 126)] condition = class_(*assets) statuses = {asset.uri: status for asset, status in zip(assets, inputs)} @@ -274,31 +280,31 @@ def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, ( (False, True, True), False, - ), # AssetAll requires all conditions to be True, but d1 is False + ), # AssetAll requires all conditions to be True, but asset1 is False ((True, True, True), True), # All conditions are True ( (True, False, True), True, - ), # d1 is True, and AssetAny condition (d2 or d3 being True) is met + ), # asset1 is True, and AssetAny condition (asset2 or asset3 being True) is met ( (True, False, False), False, - ), # d1 is True, but neither d2 nor d3 meet the AssetAny condition + ), # asset1 is True, but neither asset2 nor asset3 meet the AssetAny condition ], ) def test_nested_asset_conditions_with_serialization(status_values, expected_evaluation): # Define assets - d1 = Asset(uri="s3://abc/123") - d2 = Asset(uri="s3://abc/124") - d3 = Asset(uri="s3://abc/125") + asset1 = Asset(uri="s3://abc/123") + asset2 = Asset(uri="s3://abc/124") + asset3 = Asset(uri="s3://abc/125") - # Create a nested condition: AssetAll with d1 and AssetAny with d2 and d3 - nested_condition = AssetAll(d1, AssetAny(d2, d3)) + # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 and asset3 + nested_condition = AssetAll(asset1, AssetAny(asset2, asset3)) statuses = { - d1.uri: status_values[0], - d2.uri: status_values[1], - d3.uri: status_values[2], + asset1.uri: status_values[0], + asset2.uri: status_values[1], + asset3.uri: status_values[2], } assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" @@ -314,7 +320,7 @@ def test_nested_asset_conditions_with_serialization(status_values, expected_eval @pytest.fixture def create_test_assets(session): """Fixture to create test assets and corresponding models.""" - assets = [Asset(uri=f"hello{i}") for i in range(1, 3)] + assets = [Asset(uri=f"test://asset{i}", name=f"hello{i}") for i in range(1, 3)] for asset in assets: session.add(AssetModel(uri=asset.uri)) session.commit() @@ -380,17 +386,17 @@ def test_asset_dag_run_queue_processing(session, clear_assets, dag_maker, create @pytest.mark.usefixtures("clear_assets") def test_dag_with_complex_asset_condition(session, dag_maker): # Create Asset instances - d1 = Asset(uri="hello1") - d2 = Asset(uri="hello2") + asset1 = Asset(uri="test://asset1", name="hello1") + asset2 = Asset(uri="test://asset2", name="hello2") # Create and add AssetModel instances to the session - am1 = AssetModel(uri=d1.uri) - am2 = AssetModel(uri=d2.uri) + am1 = AssetModel(uri=asset1.uri, name=asset1.name, group="asset") + am2 = AssetModel(uri=asset2.uri, name=asset2.name, group="asset") session.add_all([am1, am2]) session.commit() # Setup a DAG with complex asset triggers (AssetAny with AssetAll) - with dag_maker(schedule=AssetAny(d1, AssetAll(d2, d1))) as dag: + with dag_maker(schedule=AssetAny(asset1, AssetAll(asset2, asset1))) as dag: EmptyOperator(task_id="hello") assert isinstance( @@ -442,11 +448,11 @@ def assets_equal(a1: BaseAsset, a2: BaseAsset) -> bool: return False -asset1 = Asset(uri="s3://bucket1/data1") -asset2 = Asset(uri="s3://bucket2/data2") -asset3 = Asset(uri="s3://bucket3/data3") -asset4 = Asset(uri="s3://bucket4/data4") -asset5 = Asset(uri="s3://bucket5/data5") +asset1 = Asset(uri="s3://bucket1/data1", name="asset-1") +asset2 = Asset(uri="s3://bucket2/data2", name="asset-2") 
+asset3 = Asset(uri="s3://bucket3/data3", name="asset-3") +asset4 = Asset(uri="s3://bucket4/data4", name="asset-4") +asset5 = Asset(uri="s3://bucket5/data5", name="asset-5") test_cases = [ (lambda: asset1, asset1), @@ -579,21 +585,27 @@ def test_normalize_uri_valid_uri(): @pytest.mark.usefixtures("clear_assets") class TestAssetAliasCondition: @pytest.fixture - def asset_1(self, session): + def asset_model(self, session): """Example asset links to asset alias resolved_asset_alias_2.""" - asset_uri = "test_uri" - asset_1 = AssetModel(id=1, uri=asset_uri) - - session.add(asset_1) + asset_model = AssetModel( + id=1, + uri="test://asset1/", + name="test_name", + group="asset", + ) + + session.add(asset_model) session.commit() - return asset_1 + return asset_model @pytest.fixture def asset_alias_1(self, session): """Example asset alias links to no assets.""" - alias_name = "test_name" - asset_alias_model = AssetAliasModel(name=alias_name) + asset_alias_model = AssetAliasModel( + name="test_name", + group="test", + ) session.add(asset_alias_model) session.commit() @@ -601,35 +613,34 @@ def asset_alias_1(self, session): return asset_alias_model @pytest.fixture - def resolved_asset_alias_2(self, session, asset_1): + def resolved_asset_alias_2(self, session, asset_model): """Example asset alias links to asset asset_alias_1.""" - asset_name = "test_name_2" - asset_alias_2 = AssetAliasModel(name=asset_name) - asset_alias_2.assets.append(asset_1) + asset_alias_2 = AssetAliasModel(name="test_name_2") + asset_alias_2.assets.append(asset_model) session.add(asset_alias_2) session.commit() return asset_alias_2 - def test_init(self, asset_alias_1, asset_1, resolved_asset_alias_2): - cond = AssetAliasCondition(name=asset_alias_1.name) + def test_init(self, asset_alias_1, asset_model, resolved_asset_alias_2): + cond = AssetAliasCondition.from_asset_alias(asset_alias_1) assert cond.objects == [] - cond = AssetAliasCondition(name=resolved_asset_alias_2.name) - assert cond.objects == [Asset(uri=asset_1.uri)] + cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2) + assert cond.objects == [Asset(uri=asset_model.uri, name=asset_model.name)] def test_as_expression(self, asset_alias_1, resolved_asset_alias_2): - for assset_alias in (asset_alias_1, resolved_asset_alias_2): - cond = AssetAliasCondition(assset_alias.name) - assert cond.as_expression() == {"alias": assset_alias.name} + for asset_alias in (asset_alias_1, resolved_asset_alias_2): + cond = AssetAliasCondition.from_asset_alias(asset_alias) + assert cond.as_expression() == {"alias": {"name": asset_alias.name, "group": asset_alias.group}} - def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1): - cond = AssetAliasCondition(asset_alias_1.name) - assert cond.evaluate({asset_1.uri: True}) is False + def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_model): + cond = AssetAliasCondition.from_asset_alias(asset_alias_1) + assert cond.evaluate({asset_model.uri: True}) is False - cond = AssetAliasCondition(resolved_asset_alias_2.name) - assert cond.evaluate({asset_1.uri: True}) is True + cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2) + assert cond.evaluate({asset_model.uri: True}) is True class TestAssetSubclasses: diff --git a/tests/api_connexion/endpoints/test_asset_endpoint.py b/tests/api_connexion/endpoints/test_asset_endpoint.py index db064ac5b443e..57bea9c6643e2 100644 --- a/tests/api_connexion/endpoints/test_asset_endpoint.py +++ b/tests/api_connexion/endpoints/test_asset_endpoint.py @@ -80,6 
+80,8 @@ def _create_asset(self, session): asset_model = AssetModel( id=1, uri="s3://bucket/key", + name="asset-name", + group="asset", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), updated_at=timezone.parse(self.default_time), @@ -103,6 +105,8 @@ def test_should_respond_200(self, session): assert response.json == { "id": 1, "uri": "s3://bucket/key", + "name": "asset-name", + "group": "asset", "extra": {"foo": "bar"}, "created_at": self.default_time, "updated_at": self.default_time, @@ -136,6 +140,8 @@ def test_should_respond_200(self, session): AssetModel( id=i, uri=f"s3://bucket/key/{i}", + name=f"asset_{i}", + group="asset", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), updated_at=timezone.parse(self.default_time), @@ -156,6 +162,8 @@ def test_should_respond_200(self, session): { "id": 1, "uri": "s3://bucket/key/1", + "name": "asset_1", + "group": "asset", "extra": {"foo": "bar"}, "created_at": self.default_time, "updated_at": self.default_time, @@ -166,6 +174,8 @@ def test_should_respond_200(self, session): { "id": 2, "uri": "s3://bucket/key/2", + "name": "asset_2", + "group": "asset", "extra": {"foo": "bar"}, "created_at": self.default_time, "updated_at": self.default_time, diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 5b4133c683958..45e6bf53376f8 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -1738,7 +1738,7 @@ def test_should_respond_404(self): @pytest.mark.need_serialized_dag class TestGetDagRunAssetTriggerEvents(TestDagRunEndpoint): def test_should_respond_200(self, dag_maker, session): - asset1 = Asset(uri="ds1") + asset1 = Asset(uri="test://asset1", name="asset1") with dag_maker(dag_id="source_dag", start_date=timezone.utcnow(), session=session): EmptyOperator(task_id="task", outlets=[asset1]) diff --git a/tests/api_connexion/schemas/test_asset_schema.py b/tests/api_connexion/schemas/test_asset_schema.py index af5e8c08b86a6..ff5a81961e949 100644 --- a/tests/api_connexion/schemas/test_asset_schema.py +++ b/tests/api_connexion/schemas/test_asset_schema.py @@ -54,6 +54,8 @@ class TestAssetSchema(TestAssetSchemaBase): def test_serialize(self, dag_maker, session): asset = Asset( uri="s3://bucket/key", + name="test_asset", + group="test-group", extra={"foo": "bar"}, ) with dag_maker(dag_id="test_asset_upstream_schema", serialized=True, session=session): @@ -70,6 +72,8 @@ def test_serialize(self, dag_maker, session): assert serialized_data == { "id": 1, "uri": "s3://bucket/key", + "name": "test_asset", + "group": "test-group", "extra": {"foo": "bar"}, "created_at": self.timestamp, "updated_at": self.timestamp, @@ -96,12 +100,14 @@ class TestAssetCollectionSchema(TestAssetSchemaBase): def test_serialize(self, session): assets = [ AssetModel( - uri=f"s3://bucket/key/{i+1}", + uri=f"s3://bucket/key/{i}", + name=f"asset_{i}", + group="test-group", extra={"foo": "bar"}, ) - for i in range(2) + for i in range(1, 3) ] - asset_aliases = [AssetAliasModel(name=f"alias_{i}") for i in range(2)] + asset_aliases = [AssetAliasModel(name=f"alias_{i}", group="test-alias-group") for i in range(2)] for asset_alias in asset_aliases: asset_alias.assets.append(assets[0]) session.add_all(assets) @@ -117,19 +123,23 @@ def test_serialize(self, session): { "id": 1, "uri": "s3://bucket/key/1", + "name": "asset_1", + "group": "test-group", "extra": {"foo": "bar"}, "created_at": self.timestamp, 
"updated_at": self.timestamp, "consuming_dags": [], "producing_tasks": [], "aliases": [ - {"id": 1, "name": "alias_0"}, - {"id": 2, "name": "alias_1"}, + {"id": 1, "name": "alias_0", "group": "test-alias-group"}, + {"id": 2, "name": "alias_1", "group": "test-alias-group"}, ], }, { "id": 2, "uri": "s3://bucket/key/2", + "name": "asset_2", + "group": "test-group", "extra": {"foo": "bar"}, "created_at": self.timestamp, "updated_at": self.timestamp, diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py index 4f1b07fb6e70f..d6438045249aa 100644 --- a/tests/api_connexion/schemas/test_dag_schema.py +++ b/tests/api_connexion/schemas/test_dag_schema.py @@ -198,8 +198,8 @@ def test_serialize_test_dag_detail_schema(url_safe_serializer): @pytest.mark.db_test def test_serialize_test_dag_with_asset_schedule_detail_schema(url_safe_serializer): - asset1 = Asset(uri="s3://bucket/obj1") - asset2 = Asset(uri="s3://bucket/obj2") + asset1 = Asset(uri="s3://bucket/obj1", name="asset1") + asset2 = Asset(uri="s3://bucket/obj2", name="asset2") dag = DAG( dag_id="test_dag", start_date=datetime(2020, 6, 19), diff --git a/tests/api_fastapi/core_api/routes/public/test_assets.py b/tests/api_fastapi/core_api/routes/public/test_assets.py index a20353d32f86d..9218cbbf820cd 100644 --- a/tests/api_fastapi/core_api/routes/public/test_assets.py +++ b/tests/api_fastapi/core_api/routes/public/test_assets.py @@ -722,7 +722,7 @@ def test_should_respond_200(self, test_client, session): } def test_invalid_attr_not_allowed(self, test_client, session): - self.create_assets() + self.create_assets(session) event_invalid_payload = {"asset_uri": "s3://bucket/key/1", "extra": {"foo": "bar"}, "fake": {}} response = test_client.post("/public/assets/events", json=event_invalid_payload) @@ -731,7 +731,7 @@ def test_invalid_attr_not_allowed(self, test_client, session): @pytest.mark.usefixtures("time_freezer") @pytest.mark.enable_redact def test_should_mask_sensitive_extra(self, test_client, session): - self.create_assets() + self.create_assets(session) event_payload = {"uri": "s3://bucket/key/1", "extra": {"password": "bar"}} response = test_client.post("/public/assets/events", json=event_payload) assert response.status_code == 200 diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py b/tests/api_fastapi/core_api/routes/ui/test_assets.py index 8eafb0f8bdd4b..7b532918496ed 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_assets.py +++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py @@ -36,7 +36,11 @@ def cleanup(): def test_next_run_assets(test_client, dag_maker): - with dag_maker(dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1")], serialized=True): + with dag_maker( + dag_id="upstream", + schedule=[Asset(uri="s3://bucket/next-run-asset/1", name="asset1")], + serialized=True, + ): EmptyOperator(task_id="task1") dag_maker.create_dagrun() @@ -46,6 +50,16 @@ def test_next_run_assets(test_client, dag_maker): assert response.status_code == 200 assert response.json() == { - "asset_expression": {"all": ["s3://bucket/key/1"]}, - "events": [{"id": 20, "uri": "s3://bucket/key/1", "lastUpdate": None}], + "asset_expression": { + "all": [ + { + "asset": { + "uri": "s3://bucket/next-run-asset/1", + "name": "asset1", + "group": "asset", + } + } + ] + }, + "events": [{"id": 20, "uri": "s3://bucket/next-run-asset/1", "lastUpdate": None}], } diff --git a/tests/assets/test_manager.py b/tests/assets/test_manager.py index aa8fbb036242e..b716056e81466 100644 --- 
a/tests/assets/test_manager.py +++ b/tests/assets/test_manager.py @@ -112,7 +112,7 @@ def create_mock_dag(): class TestAssetManager: def test_register_asset_change_asset_doesnt_exist(self, mock_task_instance): - asset = Asset(uri="asset_doesnt_exist") + asset = Asset(uri="asset_doesnt_exist", name="not exist") mock_session = mock.Mock() # Gotta mock up the query results @@ -131,12 +131,12 @@ def test_register_asset_change_asset_doesnt_exist(self, mock_task_instance): def test_register_asset_change(self, session, dag_maker, mock_task_instance): asset_manager = AssetManager() - asset = Asset(uri="test_asset_uri") + asset = Asset(uri="test://asset1", name="test_asset_uri", group="asset") dag1 = DagModel(dag_id="dag1", is_active=True) dag2 = DagModel(dag_id="dag2", is_active=True) session.add_all([dag1, dag2]) - asm = AssetModel(uri="test_asset_uri") + asm = AssetModel(uri="test://asset1/", name="test_asset_uri", group="asset") session.add(asm) asm.consuming_dags = [DagScheduleAssetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)] session.execute(delete(AssetDagRunQueue)) @@ -155,10 +155,10 @@ def test_register_asset_change_with_alias(self, session, dag_maker, mock_task_in consumer_dag_2 = DagModel(dag_id="conumser_2", is_active=True, fileloc="dag2.py") session.add_all([consumer_dag_1, consumer_dag_2]) - asm = AssetModel(uri="test_asset_uri") + asm = AssetModel(uri="test://asset1/", name="test_asset_uri", group="asset") session.add(asm) - asam = AssetAliasModel(name="test_alias_name") + asam = AssetAliasModel(name="test_alias_name", group="test") session.add(asam) asam.consuming_dags = [ DagScheduleAssetAliasReference(alias_id=asam.id, dag_id=dag.dag_id) @@ -167,8 +167,8 @@ def test_register_asset_change_with_alias(self, session, dag_maker, mock_task_in session.execute(delete(AssetDagRunQueue)) session.flush() - asset = Asset(uri="test_asset_uri") - asset_alias = AssetAlias(name="test_alias_name") + asset = Asset(uri="test://asset1", name="test_asset_uri") + asset_alias = AssetAlias(name="test_alias_name", group="test") asset_manager = AssetManager() asset_manager.register_asset_change( task_instance=mock_task_instance, @@ -187,8 +187,8 @@ def test_register_asset_change_with_alias(self, session, dag_maker, mock_task_in def test_register_asset_change_no_downstreams(self, session, mock_task_instance): asset_manager = AssetManager() - asset = Asset(uri="never_consumed") - asm = AssetModel(uri="never_consumed") + asset = Asset(uri="test://asset1", name="never_consumed") + asm = AssetModel(uri="test://asset1/", name="never_consumed", group="asset") session.add(asm) session.execute(delete(AssetDagRunQueue)) session.flush() @@ -205,11 +205,11 @@ def test_register_asset_change_notifies_asset_listener(self, session, mock_task_ asset_listener.clear() get_listener_manager().add_listener(asset_listener) - asset = Asset(uri="test_asset_uri_2") + asset = Asset(uri="test://asset1", name="test_asset_1") dag1 = DagModel(dag_id="dag3") session.add(dag1) - asm = AssetModel(uri="test_asset_uri_2") + asm = AssetModel(uri="test://asset1/", name="test_asset_1", group="asset") session.add(asm) asm.consuming_dags = [DagScheduleAssetReference(dag_id=dag1.dag_id)] session.flush() @@ -226,7 +226,7 @@ def test_create_assets_notifies_asset_listener(self, session): asset_listener.clear() get_listener_manager().add_listener(asset_listener) - asset = Asset(uri="test_asset_uri_3") + asset = Asset(uri="test://asset1", name="test_asset_1") asms = asset_manager.create_assets([asset], session=session) diff --git 
a/tests/dags/test_assets.py b/tests/dags/test_assets.py index 1fbc67a18d329..6a0b08f9ba6a1 100644 --- a/tests/dags/test_assets.py +++ b/tests/dags/test_assets.py @@ -25,8 +25,8 @@ from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk.definitions.asset import Asset -skip_task_dag_asset = Asset("s3://dag_with_skip_task/output_1.txt", extra={"hi": "bye"}) -fail_task_dag_asset = Asset("s3://dag_with_fail_task/output_1.txt", extra={"hi": "bye"}) +skip_task_dag_asset = Asset(uri="s3://dag_with_skip_task/output_1.txt", name="skip", extra={"hi": "bye"}) +fail_task_dag_asset = Asset(uri="s3://dag_with_fail_task/output_1.txt", name="fail", extra={"hi": "bye"}) def raise_skip_exc(): diff --git a/tests/dags/test_only_empty_tasks.py b/tests/dags/test_only_empty_tasks.py index e5152f1f9ad34..92c5464982453 100644 --- a/tests/dags/test_only_empty_tasks.py +++ b/tests/dags/test_only_empty_tasks.py @@ -56,4 +56,6 @@ def __init__(self, body, *args, **kwargs): EmptyOperator(task_id="test_task_on_success", on_success_callback=lambda *args, **kwargs: None) - EmptyOperator(task_id="test_task_outlets", outlets=[Asset("hello")]) + EmptyOperator( + task_id="test_task_outlets", outlets=[Asset(name="hello", uri="test://asset1", group="test-group")] + ) diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index a90bccafa41fd..b53a379e8611d 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -975,12 +975,13 @@ def test_task_decorator_asset(dag_maker, session): result = None uri = "s3://bucket/name" + asset_name = "test_asset" with dag_maker(session=session) as dag: @dag.task() def up1() -> Asset: - return Asset(uri) + return Asset(uri=uri, name=asset_name) @dag.task() def up2(src: Asset) -> str: diff --git a/tests/io/test_path.py b/tests/io/test_path.py index fd9844bc4bc58..29e67ca84649d 100644 --- a/tests/io/test_path.py +++ b/tests/io/test_path.py @@ -405,7 +405,7 @@ def test_asset(self): p = "s3" f = "bucket/object" - i = Asset(uri=f"{p}://{f}", extra={"foo": "bar"}) + i = Asset(uri=f"{p}://{f}", name="test-asset", extra={"foo": "bar"}) o = ObjectStoragePath(i) assert o.protocol == p assert o.path == f diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 4ba0b6febf8df..6b413135bfc3d 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -3979,8 +3979,8 @@ def test_create_dag_runs_assets(self, session, dag_maker): - That dag_model has next_dagrun """ - asset1 = Asset(uri="ds1") - asset2 = Asset(uri="ds2") + asset1 = Asset(uri="test://asset1", name="test_asset", group="test_group") + asset2 = Asset(uri="test://asset2", name="test_asset_2", group="test_group") with dag_maker(dag_id="assets-1", start_date=timezone.utcnow(), session=session): BashOperator(task_id="task", bash_command="echo 1", outlets=[asset1]) @@ -4075,15 +4075,14 @@ def dict_from_obj(obj): ], ) def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, enable): - ds = Asset("ds") - with dag_maker(dag_id="consumer", schedule=[ds], session=session): + asset = Asset(uri="test://asset_1", name="test_asset_1", group="test_group") + with dag_maker(dag_id="consumer", schedule=[asset], session=session): pass with dag_maker(dag_id="producer", schedule="@daily", session=session): - BashOperator(task_id="task", bash_command="echo 1", outlets=ds) + BashOperator(task_id="task", bash_command="echo 1", outlets=asset) asset_manger = AssetManager() - asset_id = 
session.scalars(select(AssetModel.id).filter_by(uri=ds.uri)).one() - + asset_id = session.scalars(select(AssetModel.id).filter_by(uri=asset.uri, name=asset.name)).one() ase_q = select(AssetEvent).where(AssetEvent.asset_id == asset_id).order_by(AssetEvent.timestamp) adrq_q = select(AssetDagRunQueue).where( AssetDagRunQueue.asset_id == asset_id, AssetDagRunQueue.target_dag_id == "consumer" @@ -4096,7 +4095,7 @@ def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) asset_manger.register_asset_change( task_instance=dr1.get_task_instance("task", session=session), - asset=ds, + asset=asset, session=session, ) session.flush() @@ -4110,7 +4109,7 @@ def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, dr2: DagRun = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) asset_manger.register_asset_change( task_instance=dr2.get_task_instance("task", session=session), - asset=ds, + asset=asset, session=session, ) session.flush() @@ -6187,11 +6186,11 @@ def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel] def test_asset_orphaning(self, dag_maker, session): self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull) - asset1 = Asset(uri="ds1") - asset2 = Asset(uri="ds2") - asset3 = Asset(uri="ds3") - asset4 = Asset(uri="ds4") - asset5 = Asset(uri="ds5") + asset1 = Asset(uri="test://asset_1", name="test_asset_1", group="test_group") + asset2 = Asset(uri="test://asset_2", name="test_asset_2", group="test_group") + asset3 = Asset(uri="test://asset_3", name="test_asset_3", group="test_group") + asset4 = Asset(uri="test://asset_4", name="test_asset_4", group="test_group") + asset5 = Asset(uri="test://asset_5", name="test_asset_5", group="test_group") with dag_maker(dag_id="assets-1", schedule=[asset1, asset2], session=session): BashOperator(task_id="task", bash_command="echo 1", outlets=[asset3, asset4]) @@ -6230,7 +6229,7 @@ def test_asset_orphaning(self, dag_maker, session): def test_asset_orphaning_ignore_orphaned_assets(self, dag_maker, session): self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull) - asset1 = Asset(uri="ds1") + asset1 = Asset(uri="test://asset_1", name="test_asset_1", group="test_group") with dag_maker(dag_id="assets-1", schedule=[asset1], session=session): BashOperator(task_id="task", bash_command="echo 1") @@ -6303,11 +6302,13 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session): asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1", extra=asset_extra) asset1_1 = Asset(name=asset1_name, uri="it's duplicate", extra=asset_extra) - dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1]) + asset1_2 = Asset(name="it's also a duplicate", uri="s3://bucket/key/1", extra=asset_extra) + dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1, asset1_2]) DAG.bulk_write_to_db([dag1], session=session) asset_models = session.scalars(select(AssetModel)).all() + assert len(asset_models) == 3 SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) session.flush() @@ -6318,8 +6319,10 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session): ) ) assert dag_warning.message == ( - "Cannot activate asset AssetModel(name='asset1', uri=\"it's duplicate\", extra={'foo': 'bar'}); " - "name is already associated to 's3://bucket/key/1'" + 'Cannot activate asset AssetModel(name="it\'s also a duplicate",' + " 
uri='s3://bucket/key/1', extra={'foo': 'bar'}); uri is already associated to 'asset1'\n" + "Cannot activate asset AssetModel(name='asset1', uri" + "=\"it's duplicate\", extra={'foo': 'bar'}); name is already associated to 's3://bucket/key/1'" ) def test_activate_referenced_assets_with_existing_warnings(self, session): diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py index ec6390c77a555..f66f6c2bf9f83 100644 --- a/tests/lineage/test_hook.py +++ b/tests/lineage/test_hook.py @@ -46,13 +46,26 @@ def test_are_assets_collected(self): assert self.collector.collected_assets == HookLineage() input_hook = BaseHook() output_hook = BaseHook() - self.collector.add_input_asset(input_hook, uri="s3://in_bucket/file") - self.collector.add_output_asset(output_hook, uri="postgres://example.com:5432/database/default/table") + self.collector.add_input_asset(input_hook, uri="s3://in_bucket/file", name="asset-1", group="test") + self.collector.add_output_asset( + output_hook, + uri="postgres://example.com:5432/database/default/table", + ) assert self.collector.collected_assets == HookLineage( - [AssetLineageInfo(asset=Asset("s3://in_bucket/file"), count=1, context=input_hook)], [ AssetLineageInfo( - asset=Asset("postgres://example.com:5432/database/default/table"), + asset=Asset(uri="s3://in_bucket/file", name="asset-1", group="test"), + count=1, + context=input_hook, + ) + ], + [ + AssetLineageInfo( + asset=Asset( + uri="postgres://example.com:5432/database/default/table", + name="postgres://example.com:5432/database/default/table", + group="asset", + ), count=1, context=output_hook, ) @@ -68,7 +81,7 @@ def test_add_input_asset(self, mock_asset): self.collector.add_input_asset(hook, uri="test_uri") assert next(iter(self.collector._inputs.values())) == (asset, hook) - mock_asset.assert_called_once_with(uri="test_uri", extra=None) + mock_asset.assert_called_once_with(uri="test_uri") def test_grouping_assets(self): hook_1 = MagicMock() @@ -95,18 +108,29 @@ def test_grouping_assets(self): @patch("airflow.lineage.hook.ProvidersManager") def test_create_asset(self, mock_providers_manager): def create_asset(arg1, arg2="default", extra=None): - return Asset(uri=f"myscheme://{arg1}/{arg2}", extra=extra or {}) + return Asset( + uri=f"myscheme://{arg1}/{arg2}", name=f"asset-{arg1}", group="test", extra=extra or {} + ) mock_providers_manager.return_value.asset_factories = {"myscheme": create_asset} assert self.collector.create_asset( - scheme="myscheme", uri=None, asset_kwargs={"arg1": "value_1"}, asset_extra=None - ) == Asset("myscheme://value_1/default") + scheme="myscheme", + uri=None, + name=None, + group=None, + asset_kwargs={"arg1": "value_1"}, + asset_extra=None, + ) == Asset(uri="myscheme://value_1/default", name="asset-value_1", group="test") assert self.collector.create_asset( scheme="myscheme", uri=None, + name=None, + group=None, asset_kwargs={"arg1": "value_1", "arg2": "value_2"}, asset_extra={"key": "value"}, - ) == Asset("myscheme://value_1/value_2", extra={"key": "value"}) + ) == Asset( + uri="myscheme://value_1/value_2", name="asset-value_1", group="test", extra={"key": "value"} + ) @patch("airflow.lineage.hook.ProvidersManager") def test_create_asset_no_factory(self, mock_providers_manager): @@ -117,7 +141,12 @@ def test_create_asset_no_factory(self, mock_providers_manager): assert ( self.collector.create_asset( - scheme=test_scheme, uri=None, asset_kwargs=test_kwargs, asset_extra=None + scheme=test_scheme, + uri=None, + name=None, + group=None, + asset_kwargs=test_kwargs, + 
asset_extra=None, ) is None ) diff --git a/tests/listeners/test_asset_listener.py b/tests/listeners/test_asset_listener.py index 7acf122829d86..ace800358f249 100644 --- a/tests/listeners/test_asset_listener.py +++ b/tests/listeners/test_asset_listener.py @@ -41,9 +41,11 @@ def clean_listener_manager(): @pytest.mark.db_test @provide_session def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_operator, session): - asset_uri = "test_asset_uri" - asset = Asset(uri=asset_uri) - asset_model = AssetModel(uri=asset_uri) + asset_uri = "test://asset/" + asset_name = "test_asset_uri" + asset_group = "test-group" + asset = Asset(uri=asset_uri, name=asset_name, group=asset_group) + asset_model = AssetModel(uri=asset_uri, name=asset_name, group=asset_group) session.add(asset_model) session.flush() @@ -59,3 +61,5 @@ def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_oper assert len(asset_listener.changed) == 1 assert asset_listener.changed[0].uri == asset_uri + assert asset_listener.changed[0].name == asset_name + assert asset_listener.changed[0].group == asset_group diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index a651c7114d603..384d76c7548b4 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -857,15 +857,24 @@ def test_bulk_write_to_db_assets(self): """ dag_id1 = "test_asset_dag1" dag_id2 = "test_asset_dag2" + task_id = "test_asset_task" + uri1 = "s3://asset/1" - a1 = Asset(uri1, extra={"not": "used"}) - a2 = Asset("s3://asset/2") - a3 = Asset("s3://asset/3") + a1 = Asset(uri=uri1, name="test_asset_1", extra={"not": "used"}, group="test-group") + a2 = Asset(uri="s3://asset/2", name="test_asset_2", group="test-group") + a3 = Asset(uri="s3://asset/3", name="test_asset-3", group="test-group") + dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[a1]) EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2, a3]) + dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) - EmptyOperator(task_id=task_id, dag=dag2, outlets=[Asset(uri1, extra={"should": "be used"})]) + EmptyOperator( + task_id=task_id, + dag=dag2, + outlets=[Asset(uri=uri1, name="test_asset_1", extra={"should": "be used"}, group="test-group")], + ) + session = settings.Session() dag1.clear() DAG.bulk_write_to_db([dag1, dag2], session=session) @@ -934,10 +943,10 @@ def test_bulk_write_to_db_does_not_activate(self, dag_maker, session): """ # Create four assets - two that have references and two that are unreferenced and marked as # orphans - asset1 = Asset(uri="ds1") - asset2 = Asset(uri="ds2") - asset3 = Asset(uri="ds3") - asset4 = Asset(uri="ds4") + asset1 = Asset(uri="test://asset1", name="asset1", group="test-group") + asset2 = Asset(uri="test://asset2", name="asset2", group="test-group") + asset3 = Asset(uri="test://asset3", name="asset3", group="test-group") + asset4 = Asset(uri="test://asset4", name="asset4", group="test-group") dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1]) BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3]) @@ -1407,8 +1416,11 @@ def test_timetable_and_description_from_schedule_arg( assert dag.timetable.description == interval_description def test_timetable_and_description_from_asset(self): - dag = DAG("test_schedule_interval_arg", schedule=[Asset(uri="hello")], start_date=TEST_DATE) - assert dag.timetable == AssetTriggeredTimetable(Asset(uri="hello")) + uri = "test://asset" + dag = DAG( + "test_schedule_interval_arg", schedule=[Asset(uri=uri, 
group="test-group")], start_date=TEST_DATE + ) + assert dag.timetable == AssetTriggeredTimetable(Asset(uri=uri, group="test-group")) assert dag.timetable.description == "Triggered by assets" @pytest.mark.parametrize( @@ -2173,7 +2185,7 @@ def test_dags_needing_dagruns_not_too_early(self): session.close() def test_dags_needing_dagruns_assets(self, dag_maker, session): - asset = Asset(uri="hello") + asset = Asset(uri="test://asset", group="test-group") with dag_maker( session=session, dag_id="my_dag", @@ -2405,8 +2417,8 @@ def test__processor_dags_folder(self, session): @pytest.mark.need_serialized_dag def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, session, dag_maker): - asset1 = Asset(uri="ds1") - asset2 = Asset(uri="ds2") + asset1 = Asset(uri="test://asset1", group="test-group") + asset2 = Asset(uri="test://asset2", name="test_asset_2", group="test-group") for dag_id, asset in [("assets-1", asset1), ("assets-2", asset2)]: with dag_maker(dag_id=dag_id, start_date=timezone.utcnow(), session=session): @@ -2455,12 +2467,17 @@ def test_asset_expression(self, session: Session) -> None: dag = DAG( dag_id="test_dag_asset_expression", schedule=AssetAny( - Asset("s3://dag1/output_1.txt", extra={"hi": "bye"}), + Asset(uri="s3://dag1/output_1.txt", extra={"hi": "bye"}, group="test-group"), AssetAll( - Asset("s3://dag2/output_1.txt", extra={"hi": "bye"}), - Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}), + Asset( + uri="s3://dag2/output_1.txt", + name="test_asset_2", + extra={"hi": "bye"}, + group="test-group", + ), + Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}, group="test-group"), ), - AssetAlias(name="test_name"), + AssetAlias(name="test_name", group="test-group"), ), start_date=datetime.datetime.min, ) @@ -2469,9 +2486,32 @@ def test_asset_expression(self, session: Session) -> None: expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one() assert expression == { "any": [ - "s3://dag1/output_1.txt", - {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]}, - {"alias": "test_name"}, + { + "asset": { + "uri": "s3://dag1/output_1.txt", + "name": "s3://dag1/output_1.txt", + "group": "test-group", + } + }, + { + "all": [ + { + "asset": { + "uri": "s3://dag2/output_1.txt", + "name": "test_asset_2", + "group": "test-group", + } + }, + { + "asset": { + "uri": "s3://dag3/output_3.txt", + "name": "s3://dag3/output_3.txt", + "group": "test-group", + } + }, + ] + }, + {"alias": {"name": "test_name", "group": "test-group"}}, ] } @@ -3026,9 +3066,9 @@ def test__time_restriction(dag_maker, dag_date, tasks_date, restrict): @pytest.mark.need_serialized_dag def test_get_asset_triggered_next_run_info(dag_maker, clear_assets): - asset1 = Asset(uri="ds1") - asset2 = Asset(uri="ds2") - asset3 = Asset(uri="ds3") + asset1 = Asset(uri="test://asset1", name="test_asset1", group="test-group") + asset2 = Asset(uri="test://asset2", group="test-group") + asset3 = Asset(uri="test://asset3", group="test-group") with dag_maker(dag_id="assets-1", schedule=[asset2]): pass dag1 = dag_maker.dag diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index 011e785626e2b..41632fe0458a3 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -243,16 +243,16 @@ def test_order_of_deps_is_consistent(self): dag_id="example", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), schedule=[ - Asset("1"), - Asset("2"), - Asset("3"), - Asset("4"), - Asset("5"), + Asset(uri="test://asset1", 
name="1"), + Asset(uri="test://asset2", name="2"), + Asset(uri="test://asset3", name="3"), + Asset(uri="test://asset4", name="4"), + Asset(uri="test://asset5", name="5"), ], ) as dag6: BashOperator( task_id="any", - outlets=[Asset("0*"), Asset("6*")], + outlets=[Asset(uri="test://asset0", name="0*"), Asset(uri="test://asset6", name="6*")], bash_command="sleep 5", ) deps_order = [x["dependency_id"] for x in SerializedDAG.serialize_dag(dag6)["dag_dependencies"]] diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index d7dbf54c1869e..3955d17477bab 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -88,7 +88,12 @@ from airflow.utils.xcom import XCOM_RETURN_KEY from tests_common.test_utils.compat import BaseOperatorLink -from tests_common.test_utils.mock_operators import AirflowLink2, CustomOperator, GoogleLink, MockOperator +from tests_common.test_utils.mock_operators import ( + AirflowLink2, + CustomOperator, + GoogleLink, + MockOperator, +) from tests_common.test_utils.timetables import ( CustomSerializationTimetable, cron_timetable, @@ -105,7 +110,10 @@ metadata=k8s.V1ObjectMeta(name="my-name"), spec=k8s.V1PodSpec( containers=[ - k8s.V1Container(name="base", volume_mounts=[k8s.V1VolumeMount(name="my-vol", mount_path="/vol/")]) + k8s.V1Container( + name="base", + volume_mounts=[k8s.V1VolumeMount(name="my-vol", mount_path="/vol/")], + ) ] ), ) @@ -133,7 +141,10 @@ "task_group": { "_group_id": None, "prefix_group_id": True, - "children": {"bash_task": ("operator", "bash_task"), "custom_task": ("operator", "custom_task")}, + "children": { + "bash_task": ("operator", "bash_task"), + "custom_task": ("operator", "custom_task"), + }, "tooltip": "", "ui_color": "CornflowerBlue", "ui_fgcolor": "#000", @@ -161,7 +172,10 @@ "ui_fgcolor": "#000", "template_ext": [".sh", ".bash"], "template_fields": ["bash_command", "env", "cwd"], - "template_fields_renderers": {"bash_command": "bash", "env": "json"}, + "template_fields_renderers": { + "bash_command": "bash", + "env": "json", + }, "bash_command": "echo {{ task.task_id }}", "task_type": "BashOperator", "_task_module": "airflow.providers.standard.operators.bash", @@ -223,7 +237,10 @@ "__var": { "DAGs": { "__type": "set", - "__var": [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT], + "__var": [ + permissions.ACTION_CAN_READ, + permissions.ACTION_CAN_EDIT, + ], } }, } @@ -462,7 +479,10 @@ def test_dag_serialization_preserves_empty_access_roles(self): serialized_dag = SerializedDAG.to_dict(dag) SerializedDAG.validate_schema(serialized_dag) - assert serialized_dag["dag"]["access_control"] == {"__type": "dict", "__var": {}} + assert serialized_dag["dag"]["access_control"] == { + "__type": "dict", + "__var": {}, + } @pytest.mark.db_test def test_dag_serialization_unregistered_custom_timetable(self): @@ -690,14 +710,21 @@ def validate_deserialized_task( default_partial_kwargs = ( BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs ) - serialized_partial_kwargs = {**default_partial_kwargs, **serialized_task.partial_kwargs} + serialized_partial_kwargs = { + **default_partial_kwargs, + **serialized_task.partial_kwargs, + } original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs} assert serialized_partial_kwargs == original_partial_kwargs @pytest.mark.parametrize( "dag_start_date, task_start_date, expected_task_start_date", [ - (datetime(2019, 8, 1, tzinfo=timezone.utc), None, 
datetime(2019, 8, 1, tzinfo=timezone.utc)), + ( + datetime(2019, 8, 1, tzinfo=timezone.utc), + None, + datetime(2019, 8, 1, tzinfo=timezone.utc), + ), ( datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 8, 2, tzinfo=timezone.utc), @@ -749,7 +776,11 @@ def test_deserialization_with_dag_context(self): @pytest.mark.parametrize( "dag_end_date, task_end_date, expected_task_end_date", [ - (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)), + ( + datetime(2019, 8, 1, tzinfo=timezone.utc), + None, + datetime(2019, 8, 1, tzinfo=timezone.utc), + ), ( datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 8, 2, tzinfo=timezone.utc), @@ -763,7 +794,12 @@ def test_deserialization_with_dag_context(self): ], ) def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_task_end_date): - dag = DAG(dag_id="simple_dag", schedule=None, start_date=datetime(2019, 8, 1), end_date=dag_end_date) + dag = DAG( + dag_id="simple_dag", + schedule=None, + start_date=datetime(2019, 8, 1), + end_date=dag_end_date, + ) BaseOperator(task_id="simple_task", dag=dag, end_date=task_end_date) serialized_dag = SerializedDAG.to_dict(dag) @@ -781,7 +817,10 @@ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_ta @pytest.mark.parametrize( "serialized_timetable, expected_timetable", [ - ({"__type": "airflow.timetables.simple.NullTimetable", "__var": {}}, NullTimetable()), + ( + {"__type": "airflow.timetables.simple.NullTimetable", "__var": {}}, + NullTimetable(), + ), ( { "__type": "airflow.timetables.interval.CronDataIntervalTimetable", @@ -789,7 +828,10 @@ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_ta }, cron_timetable("0 0 * * 0"), ), - ({"__type": "airflow.timetables.simple.OnceTimetable", "__var": {}}, OnceTimetable()), + ( + {"__type": "airflow.timetables.simple.OnceTimetable", "__var": {}}, + OnceTimetable(), + ), ( { "__type": "airflow.timetables.interval.DeltaDataIntervalTimetable", @@ -848,12 +890,24 @@ def test_deserialization_timetable_unregistered(self): @pytest.mark.parametrize( "val, expected", [ - (relativedelta(days=-1), {"__type": "relativedelta", "__var": {"days": -1}}), - (relativedelta(month=1, days=-1), {"__type": "relativedelta", "__var": {"month": 1, "days": -1}}), + ( + relativedelta(days=-1), + {"__type": "relativedelta", "__var": {"days": -1}}, + ), + ( + relativedelta(month=1, days=-1), + {"__type": "relativedelta", "__var": {"month": 1, "days": -1}}, + ), # Every friday - (relativedelta(weekday=FR), {"__type": "relativedelta", "__var": {"weekday": [4]}}), + ( + relativedelta(weekday=FR), + {"__type": "relativedelta", "__var": {"weekday": [4]}}, + ), # Every second friday - (relativedelta(weekday=FR(2)), {"__type": "relativedelta", "__var": {"weekday": [4, 2]}}), + ( + relativedelta(weekday=FR(2)), + {"__type": "relativedelta", "__var": {"weekday": [4, 2]}}, + ), ], ) def test_roundtrip_relativedelta(self, val, expected): @@ -913,7 +967,11 @@ def __init__(self, path: str): schema = {"type": "string", "pattern": r"s3:\/\/(.+?)\/(.+)"} super().__init__(default=path, schema=schema) - dag = DAG(dag_id="simple_dag", schedule=None, params={"path": S3Param("s3://my_bucket/my_path")}) + dag = DAG( + dag_id="simple_dag", + schedule=None, + params={"path": S3Param("s3://my_bucket/my_path")}, + ) with pytest.raises(SerializationError): SerializedDAG.to_dict(dag) @@ -968,11 +1026,21 @@ def test_task_params_roundtrip(self, val, expected_val): dag = DAG(dag_id="simple_dag", schedule=None) 
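The parametrized cases above assert the {"__type": ..., "__var": ...} envelope that DAG serialization produces for relativedelta values. As a minimal standalone illustration of that envelope (plain dateutil only; the helper names below are hypothetical and this is not Airflow's real serializer, which presumably handles more fields):

from dateutil.relativedelta import FR, relativedelta, weekday

def encode_relativedelta(delta: relativedelta) -> dict:
    # Keep only the attributes the test cases above exercise.
    var = {}
    if delta.month:
        var["month"] = delta.month
    if delta.days:
        var["days"] = delta.days
    if delta.weekday is not None:
        # FR -> [4]; FR(2) -> [4, 2], matching the expected test values.
        var["weekday"] = [delta.weekday.weekday] + ([delta.weekday.n] if delta.weekday.n else [])
    return {"__type": "relativedelta", "__var": var}

def decode_relativedelta(obj: dict) -> relativedelta:
    var = dict(obj["__var"])
    if "weekday" in var:
        var["weekday"] = weekday(*var["weekday"])
    return relativedelta(**var)

assert encode_relativedelta(relativedelta(weekday=FR(2))) == {
    "__type": "relativedelta",
    "__var": {"weekday": [4, 2]},
}
assert decode_relativedelta(
    encode_relativedelta(relativedelta(month=1, days=-1))
) == relativedelta(month=1, days=-1)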
if expected_val == ParamValidationError: with pytest.raises(ParamValidationError): - BaseOperator(task_id="simple_task", dag=dag, params=val, start_date=datetime(2019, 8, 1)) + BaseOperator( + task_id="simple_task", + dag=dag, + params=val, + start_date=datetime(2019, 8, 1), + ) # further tests not relevant return else: - BaseOperator(task_id="simple_task", dag=dag, params=val, start_date=datetime(2019, 8, 1)) + BaseOperator( + task_id="simple_task", + dag=dag, + params=val, + start_date=datetime(2019, 8, 1), + ) serialized_dag = SerializedDAG.to_dict(dag) deserialized_dag = SerializedDAG.from_dict(serialized_dag) @@ -1130,10 +1198,19 @@ def __ne__(self, other): ("{{ task.task_id }}", "{{ task.task_id }}"), (["{{ task.task_id }}", "{{ task.task_id }}"]), ({"foo": "{{ task.task_id }}"}, {"foo": "{{ task.task_id }}"}), - ({"foo": {"bar": "{{ task.task_id }}"}}, {"foo": {"bar": "{{ task.task_id }}"}}), ( - [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}], - [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}], + {"foo": {"bar": "{{ task.task_id }}"}}, + {"foo": {"bar": "{{ task.task_id }}"}}, + ), + ( + [ + {"foo1": {"bar": "{{ task.task_id }}"}}, + {"foo2": {"bar": "{{ task.task_id }}"}}, + ], + [ + {"foo1": {"bar": "{{ task.task_id }}"}}, + {"foo2": {"bar": "{{ task.task_id }}"}}, + ], ), ( {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}}, @@ -1141,7 +1218,9 @@ def __ne__(self, other): ), ( ClassWithCustomAttributes( - att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"] + att1="{{ task.task_id }}", + att2="{{ task.task_id }}", + template_fields=["att1"], ), "ClassWithCustomAttributes(" "{'att1': '{{ task.task_id }}', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']})", @@ -1149,10 +1228,14 @@ def __ne__(self, other): ( ClassWithCustomAttributes( nested1=ClassWithCustomAttributes( - att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"] + att1="{{ task.task_id }}", + att2="{{ task.task_id }}", + template_fields=["att1"], ), nested2=ClassWithCustomAttributes( - att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"] + att3="{{ task.task_id }}", + att4="{{ task.task_id }}", + template_fields=["att3"], ), template_fields=["nested1"], ), @@ -1172,7 +1255,11 @@ def test_templated_fields_exist_in_serialized_dag(self, templated_field, expecte we want check that non-"basic" objects are turned in to strings after deserializing. 
""" - dag = DAG("test_serialized_template_fields", schedule=None, start_date=datetime(2019, 8, 1)) + dag = DAG( + "test_serialized_template_fields", + schedule=None, + start_date=datetime(2019, 8, 1), + ) with dag: BashOperator(task_id="test", bash_command=templated_field) @@ -1410,7 +1497,11 @@ def test_setup_teardown_tasks(self): """ logical_date = datetime(2020, 1, 1) - with DAG("test_task_group_setup_teardown_tasks", schedule=None, start_date=logical_date) as dag: + with DAG( + "test_task_group_setup_teardown_tasks", + schedule=None, + start_date=logical_date, + ) as dag: EmptyOperator(task_id="setup").as_setup() EmptyOperator(task_id="teardown").as_teardown() @@ -1580,7 +1671,11 @@ class DummyTask(BaseOperator): deps = frozenset([*BaseOperator.deps, CustomTestTriggerRule()]) logical_date = datetime(2020, 1, 1) - with DAG(dag_id="test_serialize_custom_ti_deps", schedule=None, start_date=logical_date) as dag: + with DAG( + dag_id="test_serialize_custom_ti_deps", + schedule=None, + start_date=logical_date, + ) as dag: DummyTask(task_id="task1") serialize_op = SerializedBaseOperator.serialize_operator(dag.task_dict["task1"]) @@ -1668,20 +1763,26 @@ def test_dag_deps_assets_with_duplicate_asset(self): """ from airflow.providers.standard.sensors.external_task import ExternalTaskSensor - d1 = Asset("d1") - d2 = Asset("d2") - d3 = Asset("d3") - d4 = Asset("d4") + asset1 = Asset(name="asset1", uri="test://asset1") + asset2 = Asset(name="asset2", uri="test://asset2") + asset3 = Asset(name="asset3", uri="test://asset3") + asset4 = Asset(name="asset4", uri="test://asset4") logical_date = datetime(2020, 1, 1) - with DAG(dag_id="test", start_date=logical_date, schedule=[d1, d1, d1, d1, d1]) as dag: + with DAG( + dag_id="test", start_date=logical_date, schedule=[asset1, asset1, asset1, asset1, asset1] + ) as dag: ExternalTaskSensor( task_id="task1", external_dag_id="external_dag_id", mode="reschedule", ) - BashOperator(task_id="asset_writer", bash_command="echo hello", outlets=[d2, d2, d2, d3]) + BashOperator( + task_id="asset_writer", + bash_command="echo hello", + outlets=[asset2, asset2, asset2, asset3], + ) - @dag.task(outlets=[d4]) + @dag.task(outlets=[asset4]) def other_asset_writer(x): pass @@ -1695,7 +1796,7 @@ def other_asset_writer(x): "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d4", + "dependency_id": "asset4", }, { "source": "external_dag_id", @@ -1707,40 +1808,40 @@ def other_asset_writer(x): "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d3", + "dependency_id": "asset3", }, { "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d2", + "dependency_id": "asset2", }, { "source": "asset", "target": "test", "dependency_type": "asset", - "dependency_id": "d1", + "dependency_id": "asset1", }, { - "dependency_id": "d1", + "dependency_id": "asset1", "dependency_type": "asset", "source": "asset", "target": "test", }, { - "dependency_id": "d1", + "dependency_id": "asset1", "dependency_type": "asset", "source": "asset", "target": "test", }, { - "dependency_id": "d1", + "dependency_id": "asset1", "dependency_type": "asset", "source": "asset", "target": "test", }, { - "dependency_id": "d1", + "dependency_id": "asset1", "dependency_type": "asset", "source": "asset", "target": "test", @@ -1757,20 +1858,20 @@ def test_dag_deps_assets(self): """ from airflow.providers.standard.sensors.external_task import ExternalTaskSensor - d1 = Asset("d1") - d2 = Asset("d2") - d3 = Asset("d3") - d4 = 
Asset("d4") + asset1 = Asset(name="asset1", uri="test://asset1") + asset2 = Asset(name="asset2", uri="test://asset2") + asset3 = Asset(name="asset3", uri="test://asset3") + asset4 = Asset(name="asset4", uri="test://asset4") logical_date = datetime(2020, 1, 1) - with DAG(dag_id="test", start_date=logical_date, schedule=[d1]) as dag: + with DAG(dag_id="test", start_date=logical_date, schedule=[asset1]) as dag: ExternalTaskSensor( task_id="task1", external_dag_id="external_dag_id", mode="reschedule", ) - BashOperator(task_id="asset_writer", bash_command="echo hello", outlets=[d2, d3]) + BashOperator(task_id="asset_writer", bash_command="echo hello", outlets=[asset2, asset3]) - @dag.task(outlets=[d4]) + @dag.task(outlets=[asset4]) def other_asset_writer(x): pass @@ -1784,7 +1885,7 @@ def other_asset_writer(x): "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d4", + "dependency_id": "asset4", }, { "source": "external_dag_id", @@ -1796,19 +1897,19 @@ def other_asset_writer(x): "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d3", + "dependency_id": "asset3", }, { "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d2", + "dependency_id": "asset2", }, { "source": "asset", "target": "test", "dependency_type": "asset", - "dependency_id": "d1", + "dependency_id": "asset1", }, ], key=lambda x: tuple(x.values()), @@ -1821,14 +1922,20 @@ def test_derived_dag_deps_operator(self, mapped): Tests DAG dependency detection for operators, including derived classes """ from airflow.operators.empty import EmptyOperator - from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator + from airflow.providers.standard.operators.trigger_dagrun import ( + TriggerDagRunOperator, + ) class DerivedOperator(TriggerDagRunOperator): pass logical_date = datetime(2020, 1, 1) for class_ in [TriggerDagRunOperator, DerivedOperator]: - with DAG(dag_id="test_derived_dag_deps_trigger", schedule=None, start_date=logical_date) as dag: + with DAG( + dag_id="test_derived_dag_deps_trigger", + schedule=None, + start_date=logical_date, + ) as dag: task1 = EmptyOperator(task_id="task1") if mapped: task2 = class_.partial( @@ -1912,7 +2019,10 @@ def test_task_group_sorted(self): assert upstream_group_ids == ["task_group_up1", "task_group_up2"] upstream_task_ids = task_group_middle_dict["upstream_task_ids"] - assert upstream_task_ids == ["task_group_up1.task_up1", "task_group_up2.task_up2"] + assert upstream_task_ids == [ + "task_group_up1.task_up1", + "task_group_up2.task_up2", + ] downstream_group_ids = task_group_middle_dict["downstream_group_ids"] assert downstream_group_ids == ["task_group_down1", "task_group_down2"] @@ -1930,7 +2040,11 @@ def test_edge_info_serialization(self): from airflow.operators.empty import EmptyOperator from airflow.utils.edgemodifier import Label - with DAG("test_edge_info_serialization", schedule=None, start_date=datetime(2020, 1, 1)) as dag: + with DAG( + "test_edge_info_serialization", + schedule=None, + start_date=datetime(2020, 1, 1), + ) as dag: task1 = EmptyOperator(task_id="task1") task2 = EmptyOperator(task_id="task2") task1 >> Label("test label") >> task2 @@ -2024,7 +2138,11 @@ def test_dag_on_failure_callback_roundtrip(self, passed_failure_callback, expect When the callback is not set, has_on_failure_callback should not be stored in Serialized blob and so default to False on de-serialization """ - dag = DAG(dag_id="test_dag_on_failure_callback_roundtrip", schedule=None, 
**passed_failure_callback) + dag = DAG( + dag_id="test_dag_on_failure_callback_roundtrip", + schedule=None, + **passed_failure_callback, + ) BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1)) serialized_dag = SerializedDAG.to_dict(dag) @@ -2116,7 +2234,12 @@ def test_params_serialization_from_dict_upgrade(self): "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", - "params": {"my_param": {"__class": "airflow.models.param.Param", "default": "str"}}, + "params": { + "my_param": { + "__class": "airflow.models.param.Param", + "default": "str", + } + }, }, } dag = SerializedDAG.from_dict(serialized) @@ -2265,7 +2388,10 @@ def execute_complete(self): "__type": "START_TRIGGER_ARGS", "trigger_cls": "airflow.providers.standard.triggers.temporal.TimeDeltaTrigger", # "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}}, - "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}}, + "trigger_kwargs": { + "__type": "dict", + "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}, + }, "next_method": "execute_complete", "next_kwargs": None, "timeout": None, @@ -2400,7 +2526,12 @@ def test_operator_expand_xcomarg_serde(): "type": "dict-of-lists", "value": { "__type": "dict", - "__var": {"arg2": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "return_value"}}}, + "__var": { + "arg2": { + "__type": "xcomref", + "__var": {"task_id": "op1", "key": "return_value"}, + } + }, }, }, "partial_kwargs": {}, @@ -2457,7 +2588,12 @@ def test_operator_expand_kwargs_literal_serde(strict): {"__type": "dict", "__var": {"a": "x"}}, { "__type": "dict", - "__var": {"a": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "return_value"}}}, + "__var": { + "a": { + "__type": "xcomref", + "__var": {"task_id": "op1", "key": "return_value"}, + } + }, }, ], }, @@ -2481,12 +2617,18 @@ def test_operator_expand_kwargs_literal_serde(strict): # The XComArg can't be deserialized before the DAG is. 
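The comment above ("The XComArg can't be deserialized before the DAG is") is the reason the serialized form stores only a {"__type": "xcomref", "__var": {"task_id": ..., "key": ...}} placeholder, which these tests compare against _XComRef. A self-contained sketch of that two-pass idea, with hypothetical names (XComRefSketch, resolve_refs) rather than Airflow's actual _XComRef/_ExpandInputRef machinery:

from dataclasses import dataclass

@dataclass(frozen=True)
class XComRefSketch:
    # Stand-in for the serialized xcomref payload: only task_id and key survive.
    task_id: str
    key: str = "return_value"

def resolve_refs(value, tasks_by_id):
    """Second pass: swap placeholder references for real task objects once the DAG exists."""
    if isinstance(value, XComRefSketch):
        return tasks_by_id[value.task_id]
    if isinstance(value, dict):
        return {k: resolve_refs(v, tasks_by_id) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve_refs(v, tasks_by_id) for v in value]
    return value

expand_input = [{"a": "x"}, {"a": XComRefSketch("op1")}]
print(resolve_refs(expand_input, {"op1": "<op1 task object>"}))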
expand_value = op.expand_input.value - assert expand_value == [{"a": "x"}, {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}] + assert expand_value == [ + {"a": "x"}, + {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}, + ] serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) resolved_expand_value = serialized_dag.task_dict["task_2"].expand_input.value - resolved_expand_value == [{"a": "x"}, {"a": PlainXComArg(serialized_dag.task_dict["op1"])}] + resolved_expand_value == [ + {"a": "x"}, + {"a": PlainXComArg(serialized_dag.task_dict["op1"])}, + ] @pytest.mark.parametrize("strict", [True, False]) @@ -2508,7 +2650,10 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "downstream_task_ids": [], "expand_input": { "type": "list-of-dicts", - "value": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "return_value"}}, + "value": { + "__type": "xcomref", + "__var": {"task_id": "op1", "key": "return_value"}, + }, }, "partial_kwargs": {}, "task_id": "task_2", @@ -2640,7 +2785,10 @@ def x(arg1, arg2, arg3): "__type": "dict", "__var": { "arg2": {"__type": "dict", "__var": {"a": 1, "b": 2}}, - "arg3": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "return_value"}}, + "arg3": { + "__type": "xcomref", + "__var": {"task_id": "op1", "key": "return_value"}, + }, }, }, }, @@ -2650,7 +2798,11 @@ def x(arg1, arg2, arg3): "task_id": "x", "template_ext": [], "template_fields": ["templates_dict", "op_args", "op_kwargs"], - "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, + "template_fields_renderers": { + "templates_dict": "json", + "op_args": "py", + "op_kwargs": "py", + }, "_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", "python_callable_name": qualname(x), @@ -2666,7 +2818,10 @@ def x(arg1, arg2, arg3): assert deserialized.op_kwargs_expand_input == _ExpandInputRef( key="dict-of-lists", - value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}, + value={ + "arg2": {"a": 1, "b": 2}, + "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}), + }, ) assert deserialized.partial_kwargs == { "is_setup": False, @@ -2688,7 +2843,10 @@ def x(arg1, arg2, arg3): pickled = pickle.loads(pickle.dumps(deserialized)) assert pickled.op_kwargs_expand_input == _ExpandInputRef( key="dict-of-lists", - value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}, + value={ + "arg2": {"a": 1, "b": 2}, + "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}), + }, ) assert pickled.partial_kwargs == { "is_setup": False, @@ -2753,7 +2911,11 @@ def x(arg1, arg2, arg3): "task_id": "x", "template_ext": [], "template_fields": ["templates_dict", "op_args", "op_kwargs"], - "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, + "template_fields_renderers": { + "templates_dict": "json", + "op_args": "py", + "op_kwargs": "py", + }, "_disallow_kwargs_override": strict, "_expand_input_attr": "op_kwargs_expand_input", } diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py index a3a946124ff9d..2fc8ad8d17b3d 100644 --- a/tests/serialization/test_serde.py +++ b/tests/serialization/test_serde.py @@ -365,7 +365,7 @@ def test_backwards_compat_wrapped(self): assert e["extra"] == {"hi": "bye"} def test_encode_asset(self): - asset = Asset("mytest://asset") + asset = Asset(uri="mytest://asset", name="test") obj = deserialize(serialize(asset)) assert asset.uri == 
obj.uri diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 75ff736be8733..3e8e844528822 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -187,7 +187,11 @@ def __len__(self) -> int: (timezone.utcnow(), DAT.DATETIME, equal_time), (timedelta(minutes=2), DAT.TIMEDELTA, equals), (Timezone("UTC"), DAT.TIMEZONE, lambda a, b: a.name == b.name), - (relativedelta.relativedelta(hours=+1), DAT.RELATIVEDELTA, lambda a, b: a.hours == b.hours), + ( + relativedelta.relativedelta(hours=+1), + DAT.RELATIVEDELTA, + lambda a, b: a.hours == b.hours, + ), ({"test": "dict", "test-1": 1}, None, equals), (["array_item", 2], None, equals), (("tuple_item", 3), DAT.TUPLE, equals), @@ -195,7 +199,9 @@ def __len__(self) -> int: ( k8s.V1Pod( metadata=k8s.V1ObjectMeta( - name="test", annotations={"test": "annotation"}, creation_timestamp=timezone.utcnow() + name="test", + annotations={"test": "annotation"}, + creation_timestamp=timezone.utcnow(), ) ), DAT.POD, @@ -214,7 +220,14 @@ def __len__(self) -> int: ), (Resources(cpus=0.1, ram=2048), None, None), (EmptyOperator(task_id="test-task"), None, None), - (TaskGroup(group_id="test-group", dag=DAG(dag_id="test_dag", start_date=datetime.now())), None, None), + ( + TaskGroup( + group_id="test-group", + dag=DAG(dag_id="test_dag", start_date=datetime.now()), + ), + None, + None, + ), ( Param("test", "desc"), DAT.PARAM, @@ -231,8 +244,12 @@ def __len__(self) -> int: DAT.XCOM_REF, None, ), - (MockLazySelectSequence(), None, lambda a, b: len(a) == len(b) and isinstance(b, list)), - (Asset(uri="test"), DAT.ASSET, equals), + ( + MockLazySelectSequence(), + None, + lambda a, b: len(a) == len(b) and isinstance(b, list), + ), + (Asset(uri="test://asset1", name="test"), DAT.ASSET, equals), (SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals), ( Connection(conn_id="TEST_ID", uri="mysql://"), @@ -240,16 +257,24 @@ def __len__(self) -> int: lambda a, b: a.get_uri() == b.get_uri(), ), ( - OutletEventAccessor(raw_key=Asset(uri="test"), extra={"key": "value"}, asset_alias_events=[]), + OutletEventAccessor( + raw_key=Asset(uri="test://asset1", name="test", group="test-group"), + extra={"key": "value"}, + asset_alias_events=[], + ), DAT.ASSET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), ( OutletEventAccessor( - raw_key=AssetAlias(name="test_alias"), + raw_key=AssetAlias(name="test_alias", group="test-alias-group"), extra={"key": "value"}, asset_alias_events=[ - AssetAliasEvent(source_alias_name="test_alias", dest_asset_uri="test_uri", extra={}) + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_uri="test_uri", + extra={}, + ) ], ), DAT.ASSET_EVENT_ACCESSOR, @@ -295,7 +320,10 @@ def test_serialize_deserialize(input, encoded_type, cmp_func): "conn_uri", [ pytest.param("aws://", id="only-conn-type"), - pytest.param("postgres://username:password@ec2.compute.com:5432/the_database", id="all-non-extra"), + pytest.param( + "postgres://username:password@ec2.compute.com:5432/the_database", + id="all-non-extra", + ), pytest.param( "///?__extra__=%7B%22foo%22%3A+%22bar%22%2C+%22answer%22%3A+42%2C+%22" "nullable%22%3A+null%2C+%22empty%22%3A+%22%22%2C+%22zero%22%3A+0%7D", @@ -307,7 +335,10 @@ def test_backcompat_deserialize_connection(conn_uri): """Test deserialize connection which serialised by previous serializer implementation.""" from airflow.serialization.serialized_objects import BaseSerialization - conn_obj = {Encoding.TYPE: 
DAT.CONNECTION, Encoding.VAR: {"conn_id": "TEST_ID", "uri": conn_uri}} + conn_obj = { + Encoding.TYPE: DAT.CONNECTION, + Encoding.VAR: {"conn_id": "TEST_ID", "uri": conn_uri}, + } deserialized = BaseSerialization.deserialize(conn_obj) assert deserialized.get_uri() == conn_uri @@ -323,10 +354,13 @@ def test_backcompat_deserialize_connection(conn_uri): is_paused=True, ), LogTemplatePydantic: LogTemplate( - id=1, filename="test_file", elasticsearch_id="test_id", created_at=datetime.now() + id=1, + filename="test_file", + elasticsearch_id="test_id", + created_at=datetime.now(), ), DagTagPydantic: DagTag(), - AssetPydantic: Asset("uri", extra={}), + AssetPydantic: Asset(name="test", uri="test://asset1", extra={}), AssetEventPydantic: AssetEvent(), } diff --git a/tests/timetables/test_assets_timetable.py b/tests/timetables/test_assets_timetable.py index 9d572295773a1..c8c889f603ca4 100644 --- a/tests/timetables/test_assets_timetable.py +++ b/tests/timetables/test_assets_timetable.py @@ -105,7 +105,7 @@ def test_timetable() -> MockTimetable: @pytest.fixture def test_assets() -> list[Asset]: """Pytest fixture for creating a list of Asset objects.""" - return [Asset("test_asset")] + return [Asset(name="test_asset", uri="test://asset")] @pytest.fixture @@ -134,7 +134,15 @@ def test_serialization(asset_timetable: AssetOrTimeSchedule, monkeypatch: Any) - "timetable": "mock_serialized_timetable", "asset_condition": { "__type": "asset_all", - "objects": [{"__type": "asset", "uri": "test_asset", "name": "test_asset", "extra": {}}], + "objects": [ + { + "__type": "asset", + "name": "test_asset", + "uri": "test://asset/", + "group": "asset", + "extra": {}, + } + ], }, } @@ -152,7 +160,15 @@ def test_deserialization(monkeypatch: Any) -> None: "timetable": "mock_serialized_timetable", "asset_condition": { "__type": "asset_all", - "objects": [{"__type": "asset", "name": "test_asset", "uri": "test_asset", "extra": None}], + "objects": [ + { + "__type": "asset", + "name": "test_asset", + "uri": "test://asset/", + "group": "asset", + "extra": None, + } + ], }, } deserialized = AssetOrTimeSchedule.deserialize(mock_serialized_data) diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index b99681c22318a..d5d0cdb32e8e2 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -86,7 +86,7 @@ def test_encode_raises(self): ) def test_encode_xcom_asset(self): - asset = Asset("mytest://asset") + asset = Asset(uri="mytest://asset", name="mytest") s = json.dumps(asset, cls=utils_json.XComEncoder) obj = json.loads(s, cls=utils_json.XComDecoder) assert asset.uri == obj.uri diff --git a/tests/www/views/test_views_asset.py b/tests/www/views/test_views_asset.py index e4fda0aeac662..2e6668f134ad3 100644 --- a/tests/www/views/test_views_asset.py +++ b/tests/www/views/test_views_asset.py @@ -42,7 +42,10 @@ def _cleanup(self): @pytest.fixture def create_assets(self, session): def create(indexes): - assets = [AssetModel(id=i, uri=f"s3://bucket/key/{i}") for i in indexes] + assets = [ + AssetModel(id=i, uri=f"s3://bucket/key/{i}", name=f"asset-{i}", group="asset") + for i in indexes + ] session.add_all(assets) session.flush() session.add_all(AssetActive.for_asset(a) for a in assets) @@ -220,7 +223,7 @@ def test_search_uri_pattern(self, admin_client, create_assets, session): @pytest.mark.need_serialized_dag def test_correct_counts_update(self, admin_client, session, dag_maker, app, monkeypatch): with monkeypatch.context() as m: - assets = [Asset(uri=f"s3://bucket/key/{i}") for i in [1, 2, 3, 4, 5]] + assets 
= [Asset(uri=f"s3://bucket/key/{i}", name=f"asset-{i}") for i in range(1, 6)] # DAG that produces asset #1 with dag_maker(dag_id="upstream", schedule=None, serialized=True, session=session): @@ -399,7 +402,9 @@ def test_should_return_max_if_req_above(self, admin_client, create_assets, sessi class TestGetAssetNextRunSummary(TestAssetEndpoint): def test_next_run_asset_summary(self, dag_maker, admin_client): - with dag_maker(dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1")], serialized=True): + with dag_maker( + dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1", name="asset-1")], serialized=True + ): EmptyOperator(task_id="task1") response = admin_client.post("/next_run_assets_summary", data={"dag_ids": ["upstream"]}) diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 067ca9325f6c0..e2181aa702d25 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -431,8 +431,8 @@ def test_has_outlet_asset_flag(admin_client, dag_maker, session, app, monkeypatc lineagefile = File("/tmp/does_not_exist") EmptyOperator(task_id="task1") EmptyOperator(task_id="task2", outlets=[lineagefile]) - EmptyOperator(task_id="task3", outlets=[Asset("foo"), lineagefile]) - EmptyOperator(task_id="task4", outlets=[Asset("foo")]) + EmptyOperator(task_id="task3", outlets=[Asset(name="foo", uri="s3://bucket/key"), lineagefile]) + EmptyOperator(task_id="task4", outlets=[Asset(name="foo", uri="s3://bucket/key")]) m.setattr(app, "dag_bag", dag_maker.dagbag) resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True) @@ -471,7 +471,7 @@ def _expected_task_details(task_id, has_outlet_assets): @pytest.mark.need_serialized_dag def test_next_run_assets(admin_client, dag_maker, session, app, monkeypatch): with monkeypatch.context() as m: - assets = [Asset(uri=f"s3://bucket/key/{i}") for i in [1, 2]] + assets = [Asset(uri=f"s3://bucket/key/{i}", name=f"name_{i}", group="test-group") for i in [1, 2]] with dag_maker(dag_id=DAG_ID, schedule=assets, serialized=True, session=session): EmptyOperator(task_id="task1") @@ -508,7 +508,12 @@ def test_next_run_assets(admin_client, dag_maker, session, app, monkeypatch): assert resp.status_code == 200, resp.json assert resp.json == { - "asset_expression": {"all": ["s3://bucket/key/1", "s3://bucket/key/2"]}, + "asset_expression": { + "all": [ + {"asset": {"uri": "s3://bucket/key/1", "name": "name_1", "group": "test-group"}}, + {"asset": {"uri": "s3://bucket/key/2", "name": "name_2", "group": "test-group"}}, + ] + }, "events": [ {"id": asset1_id, "uri": "s3://bucket/key/1", "lastUpdate": "2022-08-02T02:00:00+00:00"}, {"id": asset2_id, "uri": "s3://bucket/key/2", "lastUpdate": None},
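The grid-view assertions above capture the shape change in the stored asset_expression: entries that used to be bare URI strings are now nested {"asset": {"uri", "name", "group"}} mappings, with the URI reused as the name when no explicit name is set. A small illustrative helper that only mimics that output shape (not Airflow's serializer; asset_expression_entry is a hypothetical name):

from __future__ import annotations

def asset_expression_entry(uri: str, name: str | None = None, group: str = "test-group") -> dict:
    # The tests above show the URI doubling as the name when none is given.
    return {"asset": {"uri": uri, "name": name or uri, "group": group}}

expected = {
    "all": [
        asset_expression_entry("s3://bucket/key/1", name="name_1"),
        asset_expression_entry("s3://bucket/key/2", name="name_2"),
    ]
}
print(expected)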