From 4ba3ee4f383d1148c4e087d0bf94d37d80b70ce4 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 Nov 2024 15:28:32 +0800 Subject: [PATCH 01/51] test(tests/www/views/test_views_grid): extend Asset test cases to include both uri and name --- tests/www/views/test_views_grid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 067ca9325f6c0..3a34ca1b6dd89 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -431,8 +431,8 @@ def test_has_outlet_asset_flag(admin_client, dag_maker, session, app, monkeypatc lineagefile = File("/tmp/does_not_exist") EmptyOperator(task_id="task1") EmptyOperator(task_id="task2", outlets=[lineagefile]) - EmptyOperator(task_id="task3", outlets=[Asset("foo"), lineagefile]) - EmptyOperator(task_id="task4", outlets=[Asset("foo")]) + EmptyOperator(task_id="task3", outlets=[Asset(name="foo", uri="s3://bucket/key"), lineagefile]) + EmptyOperator(task_id="task4", outlets=[Asset(name="foo", uri="s3://bucket/key")]) m.setattr(app, "dag_bag", dag_maker.dagbag) resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True) @@ -471,7 +471,7 @@ def _expected_task_details(task_id, has_outlet_assets): @pytest.mark.need_serialized_dag def test_next_run_assets(admin_client, dag_maker, session, app, monkeypatch): with monkeypatch.context() as m: - assets = [Asset(uri=f"s3://bucket/key/{i}") for i in [1, 2]] + assets = [Asset(uri=f"s3://bucket/key/{i}", name=f"name_{i}", group="test-group") for i in [1, 2]] with dag_maker(dag_id=DAG_ID, schedule=assets, serialized=True, session=session): EmptyOperator(task_id="task1") From 2970a8a9648369db856ccdf43c75b3039ce93cd0 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 Nov 2024 15:29:52 +0800 Subject: [PATCH 02/51] test(utils/test_json): extend Asset test cases to include both uri and name --- tests/utils/test_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index b99681c22318a..d5d0cdb32e8e2 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -86,7 +86,7 @@ def test_encode_raises(self): ) def test_encode_xcom_asset(self): - asset = Asset("mytest://asset") + asset = Asset(uri="mytest://asset", name="mytest") s = json.dumps(asset, cls=utils_json.XComEncoder) obj = json.loads(s, cls=utils_json.XComDecoder) assert asset.uri == obj.uri From b5317feffbf0d8bed4bc919bb2dc8bcee3e7266a Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 Nov 2024 15:33:05 +0800 Subject: [PATCH 03/51] test(timetables/test_assets_timetable): extend Asset test cases to include both uri and name --- tests/timetables/test_assets_timetable.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/timetables/test_assets_timetable.py b/tests/timetables/test_assets_timetable.py index 9d572295773a1..c8c889f603ca4 100644 --- a/tests/timetables/test_assets_timetable.py +++ b/tests/timetables/test_assets_timetable.py @@ -105,7 +105,7 @@ def test_timetable() -> MockTimetable: @pytest.fixture def test_assets() -> list[Asset]: """Pytest fixture for creating a list of Asset objects.""" - return [Asset("test_asset")] + return [Asset(name="test_asset", uri="test://asset")] @pytest.fixture @@ -134,7 +134,15 @@ def test_serialization(asset_timetable: AssetOrTimeSchedule, monkeypatch: Any) - "timetable": "mock_serialized_timetable", "asset_condition": { "__type": "asset_all", - "objects": 
[{"__type": "asset", "uri": "test_asset", "name": "test_asset", "extra": {}}], + "objects": [ + { + "__type": "asset", + "name": "test_asset", + "uri": "test://asset/", + "group": "asset", + "extra": {}, + } + ], }, } @@ -152,7 +160,15 @@ def test_deserialization(monkeypatch: Any) -> None: "timetable": "mock_serialized_timetable", "asset_condition": { "__type": "asset_all", - "objects": [{"__type": "asset", "name": "test_asset", "uri": "test_asset", "extra": None}], + "objects": [ + { + "__type": "asset", + "name": "test_asset", + "uri": "test://asset/", + "group": "asset", + "extra": None, + } + ], }, } deserialized = AssetOrTimeSchedule.deserialize(mock_serialized_data) From ba8f0440c7d73dd112ff1f3577a73cd4e27849ef Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 12 Nov 2024 17:24:11 +0800 Subject: [PATCH 04/51] test(listeners/test_asset_listener): extend Asset test cases to include both uri and name --- tests/listeners/test_asset_listener.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/listeners/test_asset_listener.py b/tests/listeners/test_asset_listener.py index 7acf122829d86..ace800358f249 100644 --- a/tests/listeners/test_asset_listener.py +++ b/tests/listeners/test_asset_listener.py @@ -41,9 +41,11 @@ def clean_listener_manager(): @pytest.mark.db_test @provide_session def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_operator, session): - asset_uri = "test_asset_uri" - asset = Asset(uri=asset_uri) - asset_model = AssetModel(uri=asset_uri) + asset_uri = "test://asset/" + asset_name = "test_asset_uri" + asset_group = "test-group" + asset = Asset(uri=asset_uri, name=asset_name, group=asset_group) + asset_model = AssetModel(uri=asset_uri, name=asset_name, group=asset_group) session.add(asset_model) session.flush() @@ -59,3 +61,5 @@ def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_oper assert len(asset_listener.changed) == 1 assert asset_listener.changed[0].uri == asset_uri + assert asset_listener.changed[0].name == asset_name + assert asset_listener.changed[0].group == asset_group From a4a801987e53118b753962a9f8b779614509a217 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 12 Nov 2024 18:03:41 +0800 Subject: [PATCH 05/51] test(jobs/test_scheduler_job): extend Asset test cases to include both uri and name --- tests/jobs/test_scheduler_job.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 4ba0b6febf8df..ccbd32cca2741 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -3979,8 +3979,8 @@ def test_create_dag_runs_assets(self, session, dag_maker): - That dag_model has next_dagrun """ - asset1 = Asset(uri="ds1") - asset2 = Asset(uri="ds2") + asset1 = Asset(uri="test://asset1", name="test_asset", group="test_group") + asset2 = Asset(uri="test://asset2", name="test_asset_2", group="test_group") with dag_maker(dag_id="assets-1", start_date=timezone.utcnow(), session=session): BashOperator(task_id="task", bash_command="echo 1", outlets=[asset1]) @@ -4075,15 +4075,14 @@ def dict_from_obj(obj): ], ) def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, enable): - ds = Asset("ds") - with dag_maker(dag_id="consumer", schedule=[ds], session=session): + asset = Asset(uri="test://asset_1", name="test_asset_1", group="test_group") + with dag_maker(dag_id="consumer", schedule=[asset], session=session): pass with 
dag_maker(dag_id="producer", schedule="@daily", session=session): - BashOperator(task_id="task", bash_command="echo 1", outlets=ds) + BashOperator(task_id="task", bash_command="echo 1", outlets=asset) asset_manger = AssetManager() - asset_id = session.scalars(select(AssetModel.id).filter_by(uri=ds.uri)).one() - + asset_id = session.scalars(select(AssetModel.id).filter_by(uri=asset.uri, name=asset.name)).one() ase_q = select(AssetEvent).where(AssetEvent.asset_id == asset_id).order_by(AssetEvent.timestamp) adrq_q = select(AssetDagRunQueue).where( AssetDagRunQueue.asset_id == asset_id, AssetDagRunQueue.target_dag_id == "consumer" @@ -4096,7 +4095,7 @@ def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) asset_manger.register_asset_change( task_instance=dr1.get_task_instance("task", session=session), - asset=ds, + asset=asset, session=session, ) session.flush() @@ -4110,7 +4109,7 @@ def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, dr2: DagRun = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) asset_manger.register_asset_change( task_instance=dr2.get_task_instance("task", session=session), - asset=ds, + asset=asset, session=session, ) session.flush() @@ -6187,11 +6186,11 @@ def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel] def test_asset_orphaning(self, dag_maker, session): self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull) - asset1 = Asset(uri="ds1") - asset2 = Asset(uri="ds2") - asset3 = Asset(uri="ds3") - asset4 = Asset(uri="ds4") - asset5 = Asset(uri="ds5") + asset1 = Asset(uri="test://asset_1", name="test_asset_1", group="test_group") + asset2 = Asset(uri="test://asset_2", name="test_asset_2", group="test_group") + asset3 = Asset(uri="test://asset_3", name="test_asset_3", group="test_group") + asset4 = Asset(uri="test://asset_4", name="test_asset_4", group="test_group") + asset5 = Asset(uri="test://asset_5", name="test_asset_5", group="test_group") with dag_maker(dag_id="assets-1", schedule=[asset1, asset2], session=session): BashOperator(task_id="task", bash_command="echo 1", outlets=[asset3, asset4]) @@ -6230,7 +6229,7 @@ def test_asset_orphaning(self, dag_maker, session): def test_asset_orphaning_ignore_orphaned_assets(self, dag_maker, session): self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull) - asset1 = Asset(uri="ds1") + asset1 = Asset(uri="test://asset_1", name="test_asset_1", group="test_group") with dag_maker(dag_id="assets-1", schedule=[asset1], session=session): BashOperator(task_id="task", bash_command="echo 1") From 5ebccb42018234d22dd242c1a1a90369bd4d3063 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 12 Nov 2024 19:23:31 +0800 Subject: [PATCH 06/51] test(providers/openlineage): extend Asset test cases to include both uri and name --- .../tests/openlineage/plugins/test_utils.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index e84fac1186573..3d41e87cf0152 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -334,9 +334,9 @@ def test_serialize_timetable(): from airflow.timetables.simple import AssetTriggeredTimetable asset = AssetAny( - Asset("2"), - AssetAlias("example-alias"), - Asset("3"), + Asset(name="2", uri="test://2", group="test-group"), + 
AssetAlias(name="example-alias", group="test-group"), + Asset(name="3", uri="test://3", group="test-group"), AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")), ) dag = MagicMock() @@ -347,14 +347,32 @@ def test_serialize_timetable(): "asset_condition": { "__type": DagAttributeTypes.ASSET_ANY, "objects": [ - {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "2", "uri": "2"}, + { + "__type": DagAttributeTypes.ASSET, + "extra": {}, + "uri": "test://2/", + "name": "2", + "group": "test-group", + }, {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, - {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "3", "uri": "3"}, + { + "__type": DagAttributeTypes.ASSET, + "extra": {}, + "uri": "test://3/", + "name": "3", + "group": "test-group", + }, { "__type": DagAttributeTypes.ASSET_ALL, "objects": [ {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, - {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "4", "uri": "4"}, + { + "__type": DagAttributeTypes.ASSET, + "extra": {}, + "uri": "4", + "name": "4", + "group": "asset", + }, ], }, ], From d61b072d201e3be589688091ade1a91b74ecb80a Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 13 Nov 2024 17:35:47 +0800 Subject: [PATCH 07/51] test(decorators/test_python): extend Asset test cases to include both uri and name --- tests/decorators/test_python.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index a90bccafa41fd..b53a379e8611d 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -975,12 +975,13 @@ def test_task_decorator_asset(dag_maker, session): result = None uri = "s3://bucket/name" + asset_name = "test_asset" with dag_maker(session=session) as dag: @dag.task() def up1() -> Asset: - return Asset(uri) + return Asset(uri=uri, name=asset_name) @dag.task() def up2(src: Asset) -> str: From 55973997dab22a1aab2be348481ab89b709870f6 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 13 Nov 2024 18:20:55 +0800 Subject: [PATCH 08/51] test(models/test_dag): extend asset test cases to cover name, uri, group --- tests/models/test_dag.py | 84 +++++++++++++++++++++++++++++----------- 1 file changed, 62 insertions(+), 22 deletions(-) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index a651c7114d603..97d11aa909fce 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -857,15 +857,24 @@ def test_bulk_write_to_db_assets(self): """ dag_id1 = "test_asset_dag1" dag_id2 = "test_asset_dag2" + task_id = "test_asset_task" + uri1 = "s3://asset/1" - a1 = Asset(uri1, extra={"not": "used"}) - a2 = Asset("s3://asset/2") - a3 = Asset("s3://asset/3") + a1 = Asset(uri=uri1, name="test_asset_1", extra={"not": "used"}, group="test-group") + a2 = Asset(uri="s3://asset/2", name="test_asset_2", group="test-group") + a3 = Asset(uri="s3://asset/3", name="test_asset-3", group="test-group") + dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[a1]) EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2, a3]) + dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) - EmptyOperator(task_id=task_id, dag=dag2, outlets=[Asset(uri1, extra={"should": "be used"})]) + EmptyOperator( + task_id=task_id, + dag=dag2, + outlets=[Asset(uri=uri1, name="test_asset_1", extra={"should": "be used"}, group="test-group")], + ) + session = settings.Session() dag1.clear() DAG.bulk_write_to_db([dag1, dag2], session=session) @@ -934,10 +943,10 @@ def test_bulk_write_to_db_does_not_activate(self, dag_maker, 
session): """ # Create four assets - two that have references and two that are unreferenced and marked as # orphans - asset1 = Asset(uri="ds1") - asset2 = Asset(uri="ds2") - asset3 = Asset(uri="ds3") - asset4 = Asset(uri="ds4") + asset1 = Asset(uri="test://asset1", name="asset1", group="test-group") + asset2 = Asset(uri="test://asset2", name="asset2", group="test-group") + asset3 = Asset(uri="test://asset3", name="asset3", group="test-group") + asset4 = Asset(uri="test://asset4", name="asset4", group="test-group") dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1]) BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3]) @@ -1407,8 +1416,11 @@ def test_timetable_and_description_from_schedule_arg( assert dag.timetable.description == interval_description def test_timetable_and_description_from_asset(self): - dag = DAG("test_schedule_interval_arg", schedule=[Asset(uri="hello")], start_date=TEST_DATE) - assert dag.timetable == AssetTriggeredTimetable(Asset(uri="hello")) + uri = "test://asset" + dag = DAG( + "test_schedule_interval_arg", schedule=[Asset(uri=uri, group="test-group")], start_date=TEST_DATE + ) + assert dag.timetable == AssetTriggeredTimetable(Asset(uri=uri, group="test-group")) assert dag.timetable.description == "Triggered by assets" @pytest.mark.parametrize( @@ -2173,7 +2185,7 @@ def test_dags_needing_dagruns_not_too_early(self): session.close() def test_dags_needing_dagruns_assets(self, dag_maker, session): - asset = Asset(uri="hello") + asset = Asset(uri="test://asset", group="test-group") with dag_maker( session=session, dag_id="my_dag", @@ -2405,8 +2417,8 @@ def test__processor_dags_folder(self, session): @pytest.mark.need_serialized_dag def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, session, dag_maker): - asset1 = Asset(uri="ds1") - asset2 = Asset(uri="ds2") + asset1 = Asset(uri="test://asset1", group="test-group") + asset2 = Asset(uri="test://asset2", name="test_asset_2", group="test-group") for dag_id, asset in [("assets-1", asset1), ("assets-2", asset2)]: with dag_maker(dag_id=dag_id, start_date=timezone.utcnow(), session=session): @@ -2455,10 +2467,15 @@ def test_asset_expression(self, session: Session) -> None: dag = DAG( dag_id="test_dag_asset_expression", schedule=AssetAny( - Asset("s3://dag1/output_1.txt", extra={"hi": "bye"}), + Asset(uri="s3://dag1/output_1.txt", extra={"hi": "bye"}, group="test-group"), AssetAll( - Asset("s3://dag2/output_1.txt", extra={"hi": "bye"}), - Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}), + Asset( + uri="s3://dag2/output_1.txt", + name="test_asset_2", + extra={"hi": "bye"}, + group="test-group", + ), + Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}, group="test-group"), ), AssetAlias(name="test_name"), ), @@ -2469,9 +2486,32 @@ def test_asset_expression(self, session: Session) -> None: expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one() assert expression == { "any": [ - "s3://dag1/output_1.txt", - {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]}, - {"alias": "test_name"}, + { + "asset": { + "uri": "s3://dag1/output_1.txt", + "name": "s3://dag1/output_1.txt", + "group": "test-group", + } + }, + { + "all": [ + { + "asset": { + "uri": "s3://dag2/output_1.txt", + "name": "test_asset_2", + "group": "test-group", + } + }, + { + "asset": { + "uri": "s3://dag3/output_3.txt", + "name": "s3://dag3/output_3.txt", + "group": "test-group", + } + }, + ] + }, + {"alias": {"name": "test_name", "group": ""}}, 
]
        }


@@ -3026,9 +3066,9 @@ def test__time_restriction(dag_maker, dag_date, tasks_date, restrict):
 
 @pytest.mark.need_serialized_dag
 def test_get_asset_triggered_next_run_info(dag_maker, clear_assets):
-    asset1 = Asset(uri="ds1")
-    asset2 = Asset(uri="ds2")
-    asset3 = Asset(uri="ds3")
+    asset1 = Asset(uri="test://asset1", name="test_asset1", group="test-group")
+    asset2 = Asset(uri="test://asset2", group="test-group")
+    asset3 = Asset(uri="test://asset3", group="test-group")
     with dag_maker(dag_id="assets-1", schedule=[asset2]):
         pass
     dag1 = dag_maker.dag

From c4d93482a5fce0ed87c35d4d1a34c76a71dddd0f Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Wed, 27 Nov 2024 22:54:53 +0800
Subject: [PATCH 09/51] fix(task_sdk/definitions/asset): include group in AssetAliasCondition and its serialized expression

AssetAliasCondition now receives the alias group alongside its name, and
as_expression() serializes an alias as {"name": ..., "group": ...} instead of
a bare name string.

---
 task_sdk/src/airflow/sdk/definitions/asset/__init__.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index 812c30261bb97..da44a2c6b42ed 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -439,7 +439,8 @@ def __init__(self, *objects: BaseAsset) -> None:
             raise TypeError("expect asset expressions in condition")
 
         self.objects = [
-            AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects
+            AssetAliasCondition(name=obj.name, group=obj.group) if isinstance(obj, AssetAlias) else obj
+            for obj in objects
         ]
 
     def evaluate(self, statuses: dict[str, bool]) -> bool:
@@ -515,8 +516,9 @@ class AssetAliasCondition(AssetAny):
     :meta private:
     """
 
-    def __init__(self, name: str) -> None:
+    def __init__(self, name: str, group: str) -> None:
         self.name = name
+        self.group = group
         self.objects = expand_alias_to_assets(name)
 
     def __repr__(self) -> str:
@@ -528,7 +530,7 @@ def as_expression(self) -> Any:
 
         :meta private:
         """
-        return {"alias": self.name}
+        return {"alias": {"name": self.name, "group": self.group}}
 
     def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
         yield self.name, AssetAlias(self.name)

From 93bc452e78610f033aaf2286175668b5ee9e3609 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Wed, 27 Nov 2024 18:37:21 +0800
Subject: [PATCH 10/51] test(serialization/serialized_objects): extend asset test cases to cover name, uri, group and asset alias test cases to cover name and group

---
 .../serialization/test_serialized_objects.py | 58 +++++++++++++++----
 1 file changed, 46 insertions(+), 12 deletions(-)

diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 75ff736be8733..3e8e844528822 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -187,7 +187,11 @@ def __len__(self) -> int:
     (timezone.utcnow(), DAT.DATETIME, equal_time),
     (timedelta(minutes=2), DAT.TIMEDELTA, equals),
     (Timezone("UTC"), DAT.TIMEZONE, lambda a, b: a.name == b.name),
-    (relativedelta.relativedelta(hours=+1), DAT.RELATIVEDELTA, lambda a, b: a.hours == b.hours),
+    (
+        relativedelta.relativedelta(hours=+1),
+        DAT.RELATIVEDELTA,
+        lambda a, b: a.hours == b.hours,
+    ),
     ({"test": "dict", "test-1": 1}, None, equals),
     (["array_item", 2], None, equals),
     (("tuple_item", 3), DAT.TUPLE, equals),
     (
         k8s.V1Pod(
             metadata=k8s.V1ObjectMeta(
                 name="test",
annotations={"test": "annotation"}, + creation_timestamp=timezone.utcnow(), ) ), DAT.POD, @@ -214,7 +220,14 @@ def __len__(self) -> int: ), (Resources(cpus=0.1, ram=2048), None, None), (EmptyOperator(task_id="test-task"), None, None), - (TaskGroup(group_id="test-group", dag=DAG(dag_id="test_dag", start_date=datetime.now())), None, None), + ( + TaskGroup( + group_id="test-group", + dag=DAG(dag_id="test_dag", start_date=datetime.now()), + ), + None, + None, + ), ( Param("test", "desc"), DAT.PARAM, @@ -231,8 +244,12 @@ def __len__(self) -> int: DAT.XCOM_REF, None, ), - (MockLazySelectSequence(), None, lambda a, b: len(a) == len(b) and isinstance(b, list)), - (Asset(uri="test"), DAT.ASSET, equals), + ( + MockLazySelectSequence(), + None, + lambda a, b: len(a) == len(b) and isinstance(b, list), + ), + (Asset(uri="test://asset1", name="test"), DAT.ASSET, equals), (SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals), ( Connection(conn_id="TEST_ID", uri="mysql://"), @@ -240,16 +257,24 @@ def __len__(self) -> int: lambda a, b: a.get_uri() == b.get_uri(), ), ( - OutletEventAccessor(raw_key=Asset(uri="test"), extra={"key": "value"}, asset_alias_events=[]), + OutletEventAccessor( + raw_key=Asset(uri="test://asset1", name="test", group="test-group"), + extra={"key": "value"}, + asset_alias_events=[], + ), DAT.ASSET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), ( OutletEventAccessor( - raw_key=AssetAlias(name="test_alias"), + raw_key=AssetAlias(name="test_alias", group="test-alias-group"), extra={"key": "value"}, asset_alias_events=[ - AssetAliasEvent(source_alias_name="test_alias", dest_asset_uri="test_uri", extra={}) + AssetAliasEvent( + source_alias_name="test_alias", + dest_asset_uri="test_uri", + extra={}, + ) ], ), DAT.ASSET_EVENT_ACCESSOR, @@ -295,7 +320,10 @@ def test_serialize_deserialize(input, encoded_type, cmp_func): "conn_uri", [ pytest.param("aws://", id="only-conn-type"), - pytest.param("postgres://username:password@ec2.compute.com:5432/the_database", id="all-non-extra"), + pytest.param( + "postgres://username:password@ec2.compute.com:5432/the_database", + id="all-non-extra", + ), pytest.param( "///?__extra__=%7B%22foo%22%3A+%22bar%22%2C+%22answer%22%3A+42%2C+%22" "nullable%22%3A+null%2C+%22empty%22%3A+%22%22%2C+%22zero%22%3A+0%7D", @@ -307,7 +335,10 @@ def test_backcompat_deserialize_connection(conn_uri): """Test deserialize connection which serialised by previous serializer implementation.""" from airflow.serialization.serialized_objects import BaseSerialization - conn_obj = {Encoding.TYPE: DAT.CONNECTION, Encoding.VAR: {"conn_id": "TEST_ID", "uri": conn_uri}} + conn_obj = { + Encoding.TYPE: DAT.CONNECTION, + Encoding.VAR: {"conn_id": "TEST_ID", "uri": conn_uri}, + } deserialized = BaseSerialization.deserialize(conn_obj) assert deserialized.get_uri() == conn_uri @@ -323,10 +354,13 @@ def test_backcompat_deserialize_connection(conn_uri): is_paused=True, ), LogTemplatePydantic: LogTemplate( - id=1, filename="test_file", elasticsearch_id="test_id", created_at=datetime.now() + id=1, + filename="test_file", + elasticsearch_id="test_id", + created_at=datetime.now(), ), DagTagPydantic: DagTag(), - AssetPydantic: Asset("uri", extra={}), + AssetPydantic: Asset(name="test", uri="test://asset1", extra={}), AssetEventPydantic: AssetEvent(), } From 01c8b62072c775297af77e6c4d194dfe5f8dcac7 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 19:35:50 +0800 Subject: [PATCH 11/51] test(serialization/dag_serialization): extend asset test cases to cover name, uri, group --- 
tests/serialization/test_dag_serialization.py | 304 ++++++++++++++---- 1 file changed, 233 insertions(+), 71 deletions(-) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index d7dbf54c1869e..3955d17477bab 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -88,7 +88,12 @@ from airflow.utils.xcom import XCOM_RETURN_KEY from tests_common.test_utils.compat import BaseOperatorLink -from tests_common.test_utils.mock_operators import AirflowLink2, CustomOperator, GoogleLink, MockOperator +from tests_common.test_utils.mock_operators import ( + AirflowLink2, + CustomOperator, + GoogleLink, + MockOperator, +) from tests_common.test_utils.timetables import ( CustomSerializationTimetable, cron_timetable, @@ -105,7 +110,10 @@ metadata=k8s.V1ObjectMeta(name="my-name"), spec=k8s.V1PodSpec( containers=[ - k8s.V1Container(name="base", volume_mounts=[k8s.V1VolumeMount(name="my-vol", mount_path="/vol/")]) + k8s.V1Container( + name="base", + volume_mounts=[k8s.V1VolumeMount(name="my-vol", mount_path="/vol/")], + ) ] ), ) @@ -133,7 +141,10 @@ "task_group": { "_group_id": None, "prefix_group_id": True, - "children": {"bash_task": ("operator", "bash_task"), "custom_task": ("operator", "custom_task")}, + "children": { + "bash_task": ("operator", "bash_task"), + "custom_task": ("operator", "custom_task"), + }, "tooltip": "", "ui_color": "CornflowerBlue", "ui_fgcolor": "#000", @@ -161,7 +172,10 @@ "ui_fgcolor": "#000", "template_ext": [".sh", ".bash"], "template_fields": ["bash_command", "env", "cwd"], - "template_fields_renderers": {"bash_command": "bash", "env": "json"}, + "template_fields_renderers": { + "bash_command": "bash", + "env": "json", + }, "bash_command": "echo {{ task.task_id }}", "task_type": "BashOperator", "_task_module": "airflow.providers.standard.operators.bash", @@ -223,7 +237,10 @@ "__var": { "DAGs": { "__type": "set", - "__var": [permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT], + "__var": [ + permissions.ACTION_CAN_READ, + permissions.ACTION_CAN_EDIT, + ], } }, } @@ -462,7 +479,10 @@ def test_dag_serialization_preserves_empty_access_roles(self): serialized_dag = SerializedDAG.to_dict(dag) SerializedDAG.validate_schema(serialized_dag) - assert serialized_dag["dag"]["access_control"] == {"__type": "dict", "__var": {}} + assert serialized_dag["dag"]["access_control"] == { + "__type": "dict", + "__var": {}, + } @pytest.mark.db_test def test_dag_serialization_unregistered_custom_timetable(self): @@ -690,14 +710,21 @@ def validate_deserialized_task( default_partial_kwargs = ( BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs ) - serialized_partial_kwargs = {**default_partial_kwargs, **serialized_task.partial_kwargs} + serialized_partial_kwargs = { + **default_partial_kwargs, + **serialized_task.partial_kwargs, + } original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs} assert serialized_partial_kwargs == original_partial_kwargs @pytest.mark.parametrize( "dag_start_date, task_start_date, expected_task_start_date", [ - (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)), + ( + datetime(2019, 8, 1, tzinfo=timezone.utc), + None, + datetime(2019, 8, 1, tzinfo=timezone.utc), + ), ( datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 8, 2, tzinfo=timezone.utc), @@ -749,7 +776,11 @@ def test_deserialization_with_dag_context(self): @pytest.mark.parametrize( "dag_end_date, 
task_end_date, expected_task_end_date", [ - (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)), + ( + datetime(2019, 8, 1, tzinfo=timezone.utc), + None, + datetime(2019, 8, 1, tzinfo=timezone.utc), + ), ( datetime(2019, 8, 1, tzinfo=timezone.utc), datetime(2019, 8, 2, tzinfo=timezone.utc), @@ -763,7 +794,12 @@ def test_deserialization_with_dag_context(self): ], ) def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_task_end_date): - dag = DAG(dag_id="simple_dag", schedule=None, start_date=datetime(2019, 8, 1), end_date=dag_end_date) + dag = DAG( + dag_id="simple_dag", + schedule=None, + start_date=datetime(2019, 8, 1), + end_date=dag_end_date, + ) BaseOperator(task_id="simple_task", dag=dag, end_date=task_end_date) serialized_dag = SerializedDAG.to_dict(dag) @@ -781,7 +817,10 @@ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_ta @pytest.mark.parametrize( "serialized_timetable, expected_timetable", [ - ({"__type": "airflow.timetables.simple.NullTimetable", "__var": {}}, NullTimetable()), + ( + {"__type": "airflow.timetables.simple.NullTimetable", "__var": {}}, + NullTimetable(), + ), ( { "__type": "airflow.timetables.interval.CronDataIntervalTimetable", @@ -789,7 +828,10 @@ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_ta }, cron_timetable("0 0 * * 0"), ), - ({"__type": "airflow.timetables.simple.OnceTimetable", "__var": {}}, OnceTimetable()), + ( + {"__type": "airflow.timetables.simple.OnceTimetable", "__var": {}}, + OnceTimetable(), + ), ( { "__type": "airflow.timetables.interval.DeltaDataIntervalTimetable", @@ -848,12 +890,24 @@ def test_deserialization_timetable_unregistered(self): @pytest.mark.parametrize( "val, expected", [ - (relativedelta(days=-1), {"__type": "relativedelta", "__var": {"days": -1}}), - (relativedelta(month=1, days=-1), {"__type": "relativedelta", "__var": {"month": 1, "days": -1}}), + ( + relativedelta(days=-1), + {"__type": "relativedelta", "__var": {"days": -1}}, + ), + ( + relativedelta(month=1, days=-1), + {"__type": "relativedelta", "__var": {"month": 1, "days": -1}}, + ), # Every friday - (relativedelta(weekday=FR), {"__type": "relativedelta", "__var": {"weekday": [4]}}), + ( + relativedelta(weekday=FR), + {"__type": "relativedelta", "__var": {"weekday": [4]}}, + ), # Every second friday - (relativedelta(weekday=FR(2)), {"__type": "relativedelta", "__var": {"weekday": [4, 2]}}), + ( + relativedelta(weekday=FR(2)), + {"__type": "relativedelta", "__var": {"weekday": [4, 2]}}, + ), ], ) def test_roundtrip_relativedelta(self, val, expected): @@ -913,7 +967,11 @@ def __init__(self, path: str): schema = {"type": "string", "pattern": r"s3:\/\/(.+?)\/(.+)"} super().__init__(default=path, schema=schema) - dag = DAG(dag_id="simple_dag", schedule=None, params={"path": S3Param("s3://my_bucket/my_path")}) + dag = DAG( + dag_id="simple_dag", + schedule=None, + params={"path": S3Param("s3://my_bucket/my_path")}, + ) with pytest.raises(SerializationError): SerializedDAG.to_dict(dag) @@ -968,11 +1026,21 @@ def test_task_params_roundtrip(self, val, expected_val): dag = DAG(dag_id="simple_dag", schedule=None) if expected_val == ParamValidationError: with pytest.raises(ParamValidationError): - BaseOperator(task_id="simple_task", dag=dag, params=val, start_date=datetime(2019, 8, 1)) + BaseOperator( + task_id="simple_task", + dag=dag, + params=val, + start_date=datetime(2019, 8, 1), + ) # further tests not relevant return else: - 
BaseOperator(task_id="simple_task", dag=dag, params=val, start_date=datetime(2019, 8, 1)) + BaseOperator( + task_id="simple_task", + dag=dag, + params=val, + start_date=datetime(2019, 8, 1), + ) serialized_dag = SerializedDAG.to_dict(dag) deserialized_dag = SerializedDAG.from_dict(serialized_dag) @@ -1130,10 +1198,19 @@ def __ne__(self, other): ("{{ task.task_id }}", "{{ task.task_id }}"), (["{{ task.task_id }}", "{{ task.task_id }}"]), ({"foo": "{{ task.task_id }}"}, {"foo": "{{ task.task_id }}"}), - ({"foo": {"bar": "{{ task.task_id }}"}}, {"foo": {"bar": "{{ task.task_id }}"}}), ( - [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}], - [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{ task.task_id }}"}}], + {"foo": {"bar": "{{ task.task_id }}"}}, + {"foo": {"bar": "{{ task.task_id }}"}}, + ), + ( + [ + {"foo1": {"bar": "{{ task.task_id }}"}}, + {"foo2": {"bar": "{{ task.task_id }}"}}, + ], + [ + {"foo1": {"bar": "{{ task.task_id }}"}}, + {"foo2": {"bar": "{{ task.task_id }}"}}, + ], ), ( {"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}}, @@ -1141,7 +1218,9 @@ def __ne__(self, other): ), ( ClassWithCustomAttributes( - att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"] + att1="{{ task.task_id }}", + att2="{{ task.task_id }}", + template_fields=["att1"], ), "ClassWithCustomAttributes(" "{'att1': '{{ task.task_id }}', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']})", @@ -1149,10 +1228,14 @@ def __ne__(self, other): ( ClassWithCustomAttributes( nested1=ClassWithCustomAttributes( - att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"] + att1="{{ task.task_id }}", + att2="{{ task.task_id }}", + template_fields=["att1"], ), nested2=ClassWithCustomAttributes( - att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"] + att3="{{ task.task_id }}", + att4="{{ task.task_id }}", + template_fields=["att3"], ), template_fields=["nested1"], ), @@ -1172,7 +1255,11 @@ def test_templated_fields_exist_in_serialized_dag(self, templated_field, expecte we want check that non-"basic" objects are turned in to strings after deserializing. 
""" - dag = DAG("test_serialized_template_fields", schedule=None, start_date=datetime(2019, 8, 1)) + dag = DAG( + "test_serialized_template_fields", + schedule=None, + start_date=datetime(2019, 8, 1), + ) with dag: BashOperator(task_id="test", bash_command=templated_field) @@ -1410,7 +1497,11 @@ def test_setup_teardown_tasks(self): """ logical_date = datetime(2020, 1, 1) - with DAG("test_task_group_setup_teardown_tasks", schedule=None, start_date=logical_date) as dag: + with DAG( + "test_task_group_setup_teardown_tasks", + schedule=None, + start_date=logical_date, + ) as dag: EmptyOperator(task_id="setup").as_setup() EmptyOperator(task_id="teardown").as_teardown() @@ -1580,7 +1671,11 @@ class DummyTask(BaseOperator): deps = frozenset([*BaseOperator.deps, CustomTestTriggerRule()]) logical_date = datetime(2020, 1, 1) - with DAG(dag_id="test_serialize_custom_ti_deps", schedule=None, start_date=logical_date) as dag: + with DAG( + dag_id="test_serialize_custom_ti_deps", + schedule=None, + start_date=logical_date, + ) as dag: DummyTask(task_id="task1") serialize_op = SerializedBaseOperator.serialize_operator(dag.task_dict["task1"]) @@ -1668,20 +1763,26 @@ def test_dag_deps_assets_with_duplicate_asset(self): """ from airflow.providers.standard.sensors.external_task import ExternalTaskSensor - d1 = Asset("d1") - d2 = Asset("d2") - d3 = Asset("d3") - d4 = Asset("d4") + asset1 = Asset(name="asset1", uri="test://asset1") + asset2 = Asset(name="asset2", uri="test://asset2") + asset3 = Asset(name="asset3", uri="test://asset3") + asset4 = Asset(name="asset4", uri="test://asset4") logical_date = datetime(2020, 1, 1) - with DAG(dag_id="test", start_date=logical_date, schedule=[d1, d1, d1, d1, d1]) as dag: + with DAG( + dag_id="test", start_date=logical_date, schedule=[asset1, asset1, asset1, asset1, asset1] + ) as dag: ExternalTaskSensor( task_id="task1", external_dag_id="external_dag_id", mode="reschedule", ) - BashOperator(task_id="asset_writer", bash_command="echo hello", outlets=[d2, d2, d2, d3]) + BashOperator( + task_id="asset_writer", + bash_command="echo hello", + outlets=[asset2, asset2, asset2, asset3], + ) - @dag.task(outlets=[d4]) + @dag.task(outlets=[asset4]) def other_asset_writer(x): pass @@ -1695,7 +1796,7 @@ def other_asset_writer(x): "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d4", + "dependency_id": "asset4", }, { "source": "external_dag_id", @@ -1707,40 +1808,40 @@ def other_asset_writer(x): "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d3", + "dependency_id": "asset3", }, { "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d2", + "dependency_id": "asset2", }, { "source": "asset", "target": "test", "dependency_type": "asset", - "dependency_id": "d1", + "dependency_id": "asset1", }, { - "dependency_id": "d1", + "dependency_id": "asset1", "dependency_type": "asset", "source": "asset", "target": "test", }, { - "dependency_id": "d1", + "dependency_id": "asset1", "dependency_type": "asset", "source": "asset", "target": "test", }, { - "dependency_id": "d1", + "dependency_id": "asset1", "dependency_type": "asset", "source": "asset", "target": "test", }, { - "dependency_id": "d1", + "dependency_id": "asset1", "dependency_type": "asset", "source": "asset", "target": "test", @@ -1757,20 +1858,20 @@ def test_dag_deps_assets(self): """ from airflow.providers.standard.sensors.external_task import ExternalTaskSensor - d1 = Asset("d1") - d2 = Asset("d2") - d3 = Asset("d3") - d4 = 
Asset("d4") + asset1 = Asset(name="asset1", uri="test://asset1") + asset2 = Asset(name="asset2", uri="test://asset2") + asset3 = Asset(name="asset3", uri="test://asset3") + asset4 = Asset(name="asset4", uri="test://asset4") logical_date = datetime(2020, 1, 1) - with DAG(dag_id="test", start_date=logical_date, schedule=[d1]) as dag: + with DAG(dag_id="test", start_date=logical_date, schedule=[asset1]) as dag: ExternalTaskSensor( task_id="task1", external_dag_id="external_dag_id", mode="reschedule", ) - BashOperator(task_id="asset_writer", bash_command="echo hello", outlets=[d2, d3]) + BashOperator(task_id="asset_writer", bash_command="echo hello", outlets=[asset2, asset3]) - @dag.task(outlets=[d4]) + @dag.task(outlets=[asset4]) def other_asset_writer(x): pass @@ -1784,7 +1885,7 @@ def other_asset_writer(x): "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d4", + "dependency_id": "asset4", }, { "source": "external_dag_id", @@ -1796,19 +1897,19 @@ def other_asset_writer(x): "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d3", + "dependency_id": "asset3", }, { "source": "test", "target": "asset", "dependency_type": "asset", - "dependency_id": "d2", + "dependency_id": "asset2", }, { "source": "asset", "target": "test", "dependency_type": "asset", - "dependency_id": "d1", + "dependency_id": "asset1", }, ], key=lambda x: tuple(x.values()), @@ -1821,14 +1922,20 @@ def test_derived_dag_deps_operator(self, mapped): Tests DAG dependency detection for operators, including derived classes """ from airflow.operators.empty import EmptyOperator - from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator + from airflow.providers.standard.operators.trigger_dagrun import ( + TriggerDagRunOperator, + ) class DerivedOperator(TriggerDagRunOperator): pass logical_date = datetime(2020, 1, 1) for class_ in [TriggerDagRunOperator, DerivedOperator]: - with DAG(dag_id="test_derived_dag_deps_trigger", schedule=None, start_date=logical_date) as dag: + with DAG( + dag_id="test_derived_dag_deps_trigger", + schedule=None, + start_date=logical_date, + ) as dag: task1 = EmptyOperator(task_id="task1") if mapped: task2 = class_.partial( @@ -1912,7 +2019,10 @@ def test_task_group_sorted(self): assert upstream_group_ids == ["task_group_up1", "task_group_up2"] upstream_task_ids = task_group_middle_dict["upstream_task_ids"] - assert upstream_task_ids == ["task_group_up1.task_up1", "task_group_up2.task_up2"] + assert upstream_task_ids == [ + "task_group_up1.task_up1", + "task_group_up2.task_up2", + ] downstream_group_ids = task_group_middle_dict["downstream_group_ids"] assert downstream_group_ids == ["task_group_down1", "task_group_down2"] @@ -1930,7 +2040,11 @@ def test_edge_info_serialization(self): from airflow.operators.empty import EmptyOperator from airflow.utils.edgemodifier import Label - with DAG("test_edge_info_serialization", schedule=None, start_date=datetime(2020, 1, 1)) as dag: + with DAG( + "test_edge_info_serialization", + schedule=None, + start_date=datetime(2020, 1, 1), + ) as dag: task1 = EmptyOperator(task_id="task1") task2 = EmptyOperator(task_id="task2") task1 >> Label("test label") >> task2 @@ -2024,7 +2138,11 @@ def test_dag_on_failure_callback_roundtrip(self, passed_failure_callback, expect When the callback is not set, has_on_failure_callback should not be stored in Serialized blob and so default to False on de-serialization """ - dag = DAG(dag_id="test_dag_on_failure_callback_roundtrip", schedule=None, 
**passed_failure_callback) + dag = DAG( + dag_id="test_dag_on_failure_callback_roundtrip", + schedule=None, + **passed_failure_callback, + ) BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1)) serialized_dag = SerializedDAG.to_dict(dag) @@ -2116,7 +2234,12 @@ def test_params_serialization_from_dict_upgrade(self): "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", - "params": {"my_param": {"__class": "airflow.models.param.Param", "default": "str"}}, + "params": { + "my_param": { + "__class": "airflow.models.param.Param", + "default": "str", + } + }, }, } dag = SerializedDAG.from_dict(serialized) @@ -2265,7 +2388,10 @@ def execute_complete(self): "__type": "START_TRIGGER_ARGS", "trigger_cls": "airflow.providers.standard.triggers.temporal.TimeDeltaTrigger", # "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}}, - "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}}, + "trigger_kwargs": { + "__type": "dict", + "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}, + }, "next_method": "execute_complete", "next_kwargs": None, "timeout": None, @@ -2400,7 +2526,12 @@ def test_operator_expand_xcomarg_serde(): "type": "dict-of-lists", "value": { "__type": "dict", - "__var": {"arg2": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "return_value"}}}, + "__var": { + "arg2": { + "__type": "xcomref", + "__var": {"task_id": "op1", "key": "return_value"}, + } + }, }, }, "partial_kwargs": {}, @@ -2457,7 +2588,12 @@ def test_operator_expand_kwargs_literal_serde(strict): {"__type": "dict", "__var": {"a": "x"}}, { "__type": "dict", - "__var": {"a": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "return_value"}}}, + "__var": { + "a": { + "__type": "xcomref", + "__var": {"task_id": "op1", "key": "return_value"}, + } + }, }, ], }, @@ -2481,12 +2617,18 @@ def test_operator_expand_kwargs_literal_serde(strict): # The XComArg can't be deserialized before the DAG is. 
expand_value = op.expand_input.value - assert expand_value == [{"a": "x"}, {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}] + assert expand_value == [ + {"a": "x"}, + {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}, + ] serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) resolved_expand_value = serialized_dag.task_dict["task_2"].expand_input.value - resolved_expand_value == [{"a": "x"}, {"a": PlainXComArg(serialized_dag.task_dict["op1"])}] + resolved_expand_value == [ + {"a": "x"}, + {"a": PlainXComArg(serialized_dag.task_dict["op1"])}, + ] @pytest.mark.parametrize("strict", [True, False]) @@ -2508,7 +2650,10 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "downstream_task_ids": [], "expand_input": { "type": "list-of-dicts", - "value": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "return_value"}}, + "value": { + "__type": "xcomref", + "__var": {"task_id": "op1", "key": "return_value"}, + }, }, "partial_kwargs": {}, "task_id": "task_2", @@ -2640,7 +2785,10 @@ def x(arg1, arg2, arg3): "__type": "dict", "__var": { "arg2": {"__type": "dict", "__var": {"a": 1, "b": 2}}, - "arg3": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "return_value"}}, + "arg3": { + "__type": "xcomref", + "__var": {"task_id": "op1", "key": "return_value"}, + }, }, }, }, @@ -2650,7 +2798,11 @@ def x(arg1, arg2, arg3): "task_id": "x", "template_ext": [], "template_fields": ["templates_dict", "op_args", "op_kwargs"], - "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, + "template_fields_renderers": { + "templates_dict": "json", + "op_args": "py", + "op_kwargs": "py", + }, "_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", "python_callable_name": qualname(x), @@ -2666,7 +2818,10 @@ def x(arg1, arg2, arg3): assert deserialized.op_kwargs_expand_input == _ExpandInputRef( key="dict-of-lists", - value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}, + value={ + "arg2": {"a": 1, "b": 2}, + "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}), + }, ) assert deserialized.partial_kwargs == { "is_setup": False, @@ -2688,7 +2843,10 @@ def x(arg1, arg2, arg3): pickled = pickle.loads(pickle.dumps(deserialized)) assert pickled.op_kwargs_expand_input == _ExpandInputRef( key="dict-of-lists", - value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}, + value={ + "arg2": {"a": 1, "b": 2}, + "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}), + }, ) assert pickled.partial_kwargs == { "is_setup": False, @@ -2753,7 +2911,11 @@ def x(arg1, arg2, arg3): "task_id": "x", "template_ext": [], "template_fields": ["templates_dict", "op_args", "op_kwargs"], - "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, + "template_fields_renderers": { + "templates_dict": "json", + "op_args": "py", + "op_kwargs": "py", + }, "_disallow_kwargs_override": strict, "_expand_input_attr": "op_kwargs_expand_input", } From b0884827aa1315fcf7a7e8de85d2ea2e63676eed Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 19:39:56 +0800 Subject: [PATCH 12/51] test(models/dag): extend asset test cases to cover name, uri, group --- tests/models/test_dag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 97d11aa909fce..384d76c7548b4 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ 
-2477,7 +2477,7 @@ def test_asset_expression(self, session: Session) -> None: ), Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}, group="test-group"), ), - AssetAlias(name="test_name"), + AssetAlias(name="test_name", group="test-group"), ), start_date=datetime.datetime.min, ) @@ -2511,7 +2511,7 @@ def test_asset_expression(self, session: Session) -> None: }, ] }, - {"alias": {"name": "test_name", "group": ""}}, + {"alias": {"name": "test_name", "group": "test-group"}}, ] } From d07966ca6aef5c885de9c225e9acd67ae1afbfbb Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 21:25:35 +0800 Subject: [PATCH 13/51] test(serialization/serde): extend asset test cases to cover name, uri, group --- tests/serialization/test_serde.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py index a3a946124ff9d..2fc8ad8d17b3d 100644 --- a/tests/serialization/test_serde.py +++ b/tests/serialization/test_serde.py @@ -365,7 +365,7 @@ def test_backwards_compat_wrapped(self): assert e["extra"] == {"hi": "bye"} def test_encode_asset(self): - asset = Asset("mytest://asset") + asset = Asset(uri="mytest://asset", name="test") obj = deserialize(serialize(asset)) assert asset.uri == obj.uri From 27f5c517461603a25429473046aacefb068f2a68 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 21:25:57 +0800 Subject: [PATCH 14/51] test(api_connexion/schemas/asset): extend asset test cases to cover name, uri, group --- tests/api_connexion/schemas/test_asset_schema.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/api_connexion/schemas/test_asset_schema.py b/tests/api_connexion/schemas/test_asset_schema.py index af5e8c08b86a6..dc070bf96ddbc 100644 --- a/tests/api_connexion/schemas/test_asset_schema.py +++ b/tests/api_connexion/schemas/test_asset_schema.py @@ -54,6 +54,8 @@ class TestAssetSchema(TestAssetSchemaBase): def test_serialize(self, dag_maker, session): asset = Asset( uri="s3://bucket/key", + name="test_asset", + group="test-group", extra={"foo": "bar"}, ) with dag_maker(dag_id="test_asset_upstream_schema", serialized=True, session=session): @@ -70,6 +72,8 @@ def test_serialize(self, dag_maker, session): assert serialized_data == { "id": 1, "uri": "s3://bucket/key", + "name": "test_asset", + "group": "test-group", "extra": {"foo": "bar"}, "created_at": self.timestamp, "updated_at": self.timestamp, @@ -96,10 +100,12 @@ class TestAssetCollectionSchema(TestAssetSchemaBase): def test_serialize(self, session): assets = [ AssetModel( - uri=f"s3://bucket/key/{i+1}", + uri=f"s3://bucket/key/{i}", + name=f"asset_{i}", + group="test-group", extra={"foo": "bar"}, ) - for i in range(2) + for i in range(1, 3) ] asset_aliases = [AssetAliasModel(name=f"alias_{i}") for i in range(2)] for asset_alias in asset_aliases: @@ -117,6 +123,8 @@ def test_serialize(self, session): { "id": 1, "uri": "s3://bucket/key/1", + "name": "asset_1", + "group": "test-group", "extra": {"foo": "bar"}, "created_at": self.timestamp, "updated_at": self.timestamp, @@ -130,6 +138,8 @@ def test_serialize(self, session): { "id": 2, "uri": "s3://bucket/key/2", + "name": "asset_2", + "group": "test-group", "extra": {"foo": "bar"}, "created_at": self.timestamp, "updated_at": self.timestamp, From b46ed072e3e474ed787f2bc3efc716797c83e022 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 21:28:29 +0800 Subject: [PATCH 15/51] test(api_connexion/schemas/asset): extend asset alias test cases to cover name, group --- 
tests/api_connexion/schemas/test_asset_schema.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/api_connexion/schemas/test_asset_schema.py b/tests/api_connexion/schemas/test_asset_schema.py
index dc070bf96ddbc..ff5a81961e949 100644
--- a/tests/api_connexion/schemas/test_asset_schema.py
+++ b/tests/api_connexion/schemas/test_asset_schema.py
@@ -107,7 +107,7 @@ def test_serialize(self, session):
             )
             for i in range(1, 3)
         ]
-        asset_aliases = [AssetAliasModel(name=f"alias_{i}") for i in range(2)]
+        asset_aliases = [AssetAliasModel(name=f"alias_{i}", group="test-alias-group") for i in range(2)]
         for asset_alias in asset_aliases:
             asset_alias.assets.append(assets[0])
         session.add_all(assets)
@@ -131,8 +131,8 @@ def test_serialize(self, session):
                 "consuming_dags": [],
                 "producing_tasks": [],
                 "aliases": [
-                    {"id": 1, "name": "alias_0"},
-                    {"id": 2, "name": "alias_1"},
+                    {"id": 1, "name": "alias_0", "group": "test-alias-group"},
+                    {"id": 2, "name": "alias_1", "group": "test-alias-group"},
                 ],
             },
             {

From 6a0907e5a30fd951925b1538ed1f7f47b3f0f80f Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Wed, 27 Nov 2024 21:30:25 +0800
Subject: [PATCH 16/51] test(api_connexion/schemas/dag): extend asset test cases to cover name, uri, group

---
 tests/api_connexion/schemas/test_dag_schema.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py
index 4f1b07fb6e70f..d6438045249aa 100644
--- a/tests/api_connexion/schemas/test_dag_schema.py
+++ b/tests/api_connexion/schemas/test_dag_schema.py
@@ -198,8 +198,8 @@ def test_serialize_test_dag_detail_schema(url_safe_serializer):
 
 @pytest.mark.db_test
 def test_serialize_test_dag_with_asset_schedule_detail_schema(url_safe_serializer):
-    asset1 = Asset(uri="s3://bucket/obj1")
-    asset2 = Asset(uri="s3://bucket/obj2")
+    asset1 = Asset(uri="s3://bucket/obj1", name="asset1")
+    asset2 = Asset(uri="s3://bucket/obj2", name="asset2")
     dag = DAG(
         dag_id="test_dag",
         start_date=datetime(2020, 6, 19),

From f29090ee8a8af8ae9d8cc791de40a066261bf2fd Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Wed, 27 Nov 2024 21:32:10 +0800
Subject: [PATCH 17/51] test(api_connexion/endpoints/dag_run): extend asset test cases to cover name, uri, group

---
 tests/api_connexion/endpoints/test_dag_run_endpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index 5b4133c683958..45e6bf53376f8 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -1738,7 +1738,7 @@ def test_should_respond_404(self):
 @pytest.mark.need_serialized_dag
 class TestGetDagRunAssetTriggerEvents(TestDagRunEndpoint):
     def test_should_respond_200(self, dag_maker, session):
-        asset1 = Asset(uri="ds1")
+        asset1 = Asset(uri="test://asset1", name="asset1")
         with dag_maker(dag_id="source_dag", start_date=timezone.utcnow(), session=session):
             EmptyOperator(task_id="task", outlets=[asset1])

From b2a98ef42f216771240b03f025646d568939510b Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Wed, 27 Nov 2024 21:33:51 +0800
Subject: [PATCH 18/51] test(dags/test_assets): extend asset test cases to cover name, uri, group

---
 tests/dags/test_assets.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/dags/test_assets.py b/tests/dags/test_assets.py
index 1fbc67a18d329..6a0b08f9ba6a1 100644
--- 
a/tests/dags/test_assets.py +++ b/tests/dags/test_assets.py @@ -25,8 +25,8 @@ from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk.definitions.asset import Asset -skip_task_dag_asset = Asset("s3://dag_with_skip_task/output_1.txt", extra={"hi": "bye"}) -fail_task_dag_asset = Asset("s3://dag_with_fail_task/output_1.txt", extra={"hi": "bye"}) +skip_task_dag_asset = Asset(uri="s3://dag_with_skip_task/output_1.txt", name="skip", extra={"hi": "bye"}) +fail_task_dag_asset = Asset(uri="s3://dag_with_fail_task/output_1.txt", name="fail", extra={"hi": "bye"}) def raise_skip_exc(): From 97d8ece94fd819141b2d5aa0d97afd2851c653a1 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 21:34:37 +0800 Subject: [PATCH 19/51] test(dags/test_only_empty_tasks): extend asset test cases to cover name, uri, group --- tests/dags/test_only_empty_tasks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/dags/test_only_empty_tasks.py b/tests/dags/test_only_empty_tasks.py index e5152f1f9ad34..92c5464982453 100644 --- a/tests/dags/test_only_empty_tasks.py +++ b/tests/dags/test_only_empty_tasks.py @@ -56,4 +56,6 @@ def __init__(self, body, *args, **kwargs): EmptyOperator(task_id="test_task_on_success", on_success_callback=lambda *args, **kwargs: None) - EmptyOperator(task_id="test_task_outlets", outlets=[Asset("hello")]) + EmptyOperator( + task_id="test_task_outlets", outlets=[Asset(name="hello", uri="test://asset1", group="test-group")] + ) From 4e699a18c1efe6b7fa9bbc8c4779ac8935eceaff Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 21:35:33 +0800 Subject: [PATCH 20/51] test(api_fastapi): extend asset test cases to cover name, uri, group --- .../core_api/routes/ui/test_assets.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py b/tests/api_fastapi/core_api/routes/ui/test_assets.py index 8eafb0f8bdd4b..7b532918496ed 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_assets.py +++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py @@ -36,7 +36,11 @@ def cleanup(): def test_next_run_assets(test_client, dag_maker): - with dag_maker(dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1")], serialized=True): + with dag_maker( + dag_id="upstream", + schedule=[Asset(uri="s3://bucket/next-run-asset/1", name="asset1")], + serialized=True, + ): EmptyOperator(task_id="task1") dag_maker.create_dagrun() @@ -46,6 +50,16 @@ def test_next_run_assets(test_client, dag_maker): assert response.status_code == 200 assert response.json() == { - "asset_expression": {"all": ["s3://bucket/key/1"]}, - "events": [{"id": 20, "uri": "s3://bucket/key/1", "lastUpdate": None}], + "asset_expression": { + "all": [ + { + "asset": { + "uri": "s3://bucket/next-run-asset/1", + "name": "asset1", + "group": "asset", + } + } + ] + }, + "events": [{"id": 20, "uri": "s3://bucket/next-run-asset/1", "lastUpdate": None}], } From 24ae8c18e03e178e8de39b5d010d3d91cc0fbfd3 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 22:02:54 +0800 Subject: [PATCH 21/51] test(assets/manager): extend asset test cases to cover name, uri, group --- tests/assets/test_manager.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/assets/test_manager.py b/tests/assets/test_manager.py index aa8fbb036242e..b716056e81466 100644 --- a/tests/assets/test_manager.py +++ b/tests/assets/test_manager.py @@ -112,7 +112,7 @@ def create_mock_dag(): class 
TestAssetManager: def test_register_asset_change_asset_doesnt_exist(self, mock_task_instance): - asset = Asset(uri="asset_doesnt_exist") + asset = Asset(uri="asset_doesnt_exist", name="not exist") mock_session = mock.Mock() # Gotta mock up the query results @@ -131,12 +131,12 @@ def test_register_asset_change_asset_doesnt_exist(self, mock_task_instance): def test_register_asset_change(self, session, dag_maker, mock_task_instance): asset_manager = AssetManager() - asset = Asset(uri="test_asset_uri") + asset = Asset(uri="test://asset1", name="test_asset_uri", group="asset") dag1 = DagModel(dag_id="dag1", is_active=True) dag2 = DagModel(dag_id="dag2", is_active=True) session.add_all([dag1, dag2]) - asm = AssetModel(uri="test_asset_uri") + asm = AssetModel(uri="test://asset1/", name="test_asset_uri", group="asset") session.add(asm) asm.consuming_dags = [DagScheduleAssetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)] session.execute(delete(AssetDagRunQueue)) @@ -155,10 +155,10 @@ def test_register_asset_change_with_alias(self, session, dag_maker, mock_task_in consumer_dag_2 = DagModel(dag_id="conumser_2", is_active=True, fileloc="dag2.py") session.add_all([consumer_dag_1, consumer_dag_2]) - asm = AssetModel(uri="test_asset_uri") + asm = AssetModel(uri="test://asset1/", name="test_asset_uri", group="asset") session.add(asm) - asam = AssetAliasModel(name="test_alias_name") + asam = AssetAliasModel(name="test_alias_name", group="test") session.add(asam) asam.consuming_dags = [ DagScheduleAssetAliasReference(alias_id=asam.id, dag_id=dag.dag_id) @@ -167,8 +167,8 @@ def test_register_asset_change_with_alias(self, session, dag_maker, mock_task_in session.execute(delete(AssetDagRunQueue)) session.flush() - asset = Asset(uri="test_asset_uri") - asset_alias = AssetAlias(name="test_alias_name") + asset = Asset(uri="test://asset1", name="test_asset_uri") + asset_alias = AssetAlias(name="test_alias_name", group="test") asset_manager = AssetManager() asset_manager.register_asset_change( task_instance=mock_task_instance, @@ -187,8 +187,8 @@ def test_register_asset_change_with_alias(self, session, dag_maker, mock_task_in def test_register_asset_change_no_downstreams(self, session, mock_task_instance): asset_manager = AssetManager() - asset = Asset(uri="never_consumed") - asm = AssetModel(uri="never_consumed") + asset = Asset(uri="test://asset1", name="never_consumed") + asm = AssetModel(uri="test://asset1/", name="never_consumed", group="asset") session.add(asm) session.execute(delete(AssetDagRunQueue)) session.flush() @@ -205,11 +205,11 @@ def test_register_asset_change_notifies_asset_listener(self, session, mock_task_ asset_listener.clear() get_listener_manager().add_listener(asset_listener) - asset = Asset(uri="test_asset_uri_2") + asset = Asset(uri="test://asset1", name="test_asset_1") dag1 = DagModel(dag_id="dag3") session.add(dag1) - asm = AssetModel(uri="test_asset_uri_2") + asm = AssetModel(uri="test://asset1/", name="test_asset_1", group="asset") session.add(asm) asm.consuming_dags = [DagScheduleAssetReference(dag_id=dag1.dag_id)] session.flush() @@ -226,7 +226,7 @@ def test_create_assets_notifies_asset_listener(self, session): asset_listener.clear() get_listener_manager().add_listener(asset_listener) - asset = Asset(uri="test_asset_uri_3") + asset = Asset(uri="test://asset1", name="test_asset_1") asms = asset_manager.create_assets([asset], session=session) From f0bf926ca4407fffa80e08e4aff42f30a7be3396 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 22:54:38 +0800 Subject: [PATCH 
22/51] test(task_sdk/assets): extend asset test cases to cover name, uri, group --- task_sdk/tests/defintions/test_asset.py | 105 +++++++++++++----------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index ef602ea5a2267..1a580061f65ec 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -176,12 +176,12 @@ def test_asset_iter_assets(): @pytest.mark.db_test def test_asset_iter_asset_aliases(): base_asset = AssetAll( - AssetAlias("example-alias-1"), + AssetAlias(name="example-alias-1"), Asset("1"), AssetAny( - Asset("2"), + Asset(name="2", uri="test://asset1"), AssetAlias("example-alias-2"), - Asset("3"), + Asset(name="3"), AssetAll(AssetAlias("example-alias-3"), Asset("4"), AssetAlias("example-alias-4")), ), AssetAll(AssetAlias("example-alias-5"), Asset("5")), @@ -254,7 +254,7 @@ def test_assset_boolean_condition_evaluate_iter(): ) def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected): class_ = AssetAny if scenario == "any" else AssetAll - assets = [Asset(uri=f"s3://abc/{i}") for i in range(123, 126)] + assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in range(123, 126)] condition = class_(*assets) statuses = {asset.uri: status for asset, status in zip(assets, inputs)} @@ -274,31 +274,31 @@ def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, ( (False, True, True), False, - ), # AssetAll requires all conditions to be True, but d1 is False + ), # AssetAll requires all conditions to be True, but asset1 is False ((True, True, True), True), # All conditions are True ( (True, False, True), True, - ), # d1 is True, and AssetAny condition (d2 or d3 being True) is met + ), # asset1 is True, and AssetAny condition (asset2 or asset3 being True) is met ( (True, False, False), False, - ), # d1 is True, but neither d2 nor d3 meet the AssetAny condition + ), # asset1 is True, but neither asset2 nor asset3 meet the AssetAny condition ], ) def test_nested_asset_conditions_with_serialization(status_values, expected_evaluation): # Define assets - d1 = Asset(uri="s3://abc/123") - d2 = Asset(uri="s3://abc/124") - d3 = Asset(uri="s3://abc/125") + asset1 = Asset(uri="s3://abc/123") + asset2 = Asset(uri="s3://abc/124") + asset3 = Asset(uri="s3://abc/125") - # Create a nested condition: AssetAll with d1 and AssetAny with d2 and d3 - nested_condition = AssetAll(d1, AssetAny(d2, d3)) + # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 and asset3 + nested_condition = AssetAll(asset1, AssetAny(asset2, asset3)) statuses = { - d1.uri: status_values[0], - d2.uri: status_values[1], - d3.uri: status_values[2], + asset1.uri: status_values[0], + asset2.uri: status_values[1], + asset3.uri: status_values[2], } assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" @@ -314,7 +314,7 @@ def test_nested_asset_conditions_with_serialization(status_values, expected_eval @pytest.fixture def create_test_assets(session): """Fixture to create test assets and corresponding models.""" - assets = [Asset(uri=f"hello{i}") for i in range(1, 3)] + assets = [Asset(uri=f"test://asset{i}", name=f"hello{i}") for i in range(1, 3)] for asset in assets: session.add(AssetModel(uri=asset.uri)) session.commit() @@ -380,17 +380,17 @@ def test_asset_dag_run_queue_processing(session, clear_assets, dag_maker, create @pytest.mark.usefixtures("clear_assets") def 
test_dag_with_complex_asset_condition(session, dag_maker): # Create Asset instances - d1 = Asset(uri="hello1") - d2 = Asset(uri="hello2") + asset1 = Asset(uri="test://asset1", name="hello1") + asset2 = Asset(uri="test://asset2", name="hello2") # Create and add AssetModel instances to the session - am1 = AssetModel(uri=d1.uri) - am2 = AssetModel(uri=d2.uri) + am1 = AssetModel(uri=asset1.uri, name=asset1.name, group="asset") + am2 = AssetModel(uri=asset2.uri, name=asset2.name, group="asset") session.add_all([am1, am2]) session.commit() # Setup a DAG with complex asset triggers (AssetAny with AssetAll) - with dag_maker(schedule=AssetAny(d1, AssetAll(d2, d1))) as dag: + with dag_maker(schedule=AssetAny(asset1, AssetAll(asset2, asset1))) as dag: EmptyOperator(task_id="hello") assert isinstance( @@ -442,11 +442,11 @@ def assets_equal(a1: BaseAsset, a2: BaseAsset) -> bool: return False -asset1 = Asset(uri="s3://bucket1/data1") -asset2 = Asset(uri="s3://bucket2/data2") -asset3 = Asset(uri="s3://bucket3/data3") -asset4 = Asset(uri="s3://bucket4/data4") -asset5 = Asset(uri="s3://bucket5/data5") +asset1 = Asset(uri="s3://bucket1/data1", name="asset-1") +asset2 = Asset(uri="s3://bucket2/data2", name="asset-2") +asset3 = Asset(uri="s3://bucket3/data3", name="asset-3") +asset4 = Asset(uri="s3://bucket4/data4", name="asset-4") +asset5 = Asset(uri="s3://bucket5/data5", name="asset-5") test_cases = [ (lambda: asset1, asset1), @@ -579,21 +579,27 @@ def test_normalize_uri_valid_uri(): @pytest.mark.usefixtures("clear_assets") class TestAssetAliasCondition: @pytest.fixture - def asset_1(self, session): + def asset_model(self, session): """Example asset links to asset alias resolved_asset_alias_2.""" - asset_uri = "test_uri" - asset_1 = AssetModel(id=1, uri=asset_uri) - - session.add(asset_1) + asset_model = AssetModel( + id=1, + uri="test://asset1/", + name="test_name", + group="asset", + ) + + session.add(asset_model) session.commit() - return asset_1 + return asset_model @pytest.fixture def asset_alias_1(self, session): """Example asset alias links to no assets.""" - alias_name = "test_name" - asset_alias_model = AssetAliasModel(name=alias_name) + asset_alias_model = AssetAliasModel( + name="test_name", + group="test", + ) session.add(asset_alias_model) session.commit() @@ -601,35 +607,34 @@ def asset_alias_1(self, session): return asset_alias_model @pytest.fixture - def resolved_asset_alias_2(self, session, asset_1): + def resolved_asset_alias_2(self, session, asset_model): """Example asset alias links to asset asset_alias_1.""" - asset_name = "test_name_2" - asset_alias_2 = AssetAliasModel(name=asset_name) - asset_alias_2.assets.append(asset_1) + asset_alias_2 = AssetAliasModel(name="test_name_2") + asset_alias_2.assets.append(asset_model) session.add(asset_alias_2) session.commit() return asset_alias_2 - def test_init(self, asset_alias_1, asset_1, resolved_asset_alias_2): - cond = AssetAliasCondition(name=asset_alias_1.name) + def test_init(self, asset_alias_1, asset_model, resolved_asset_alias_2): + cond = AssetAliasCondition(name=asset_alias_1.name, group=asset_alias_1.group) assert cond.objects == [] - cond = AssetAliasCondition(name=resolved_asset_alias_2.name) - assert cond.objects == [Asset(uri=asset_1.uri)] + cond = AssetAliasCondition(name=resolved_asset_alias_2.name, group=resolved_asset_alias_2.group) + assert cond.objects == [Asset(uri=asset_model.uri, name=asset_model.name)] def test_as_expression(self, asset_alias_1, resolved_asset_alias_2): - for assset_alias in (asset_alias_1, 
resolved_asset_alias_2): - cond = AssetAliasCondition(assset_alias.name) - assert cond.as_expression() == {"alias": assset_alias.name} + for asset_alias in (asset_alias_1, resolved_asset_alias_2): + cond = AssetAliasCondition(name=asset_alias.name, group=asset_alias.group) + assert cond.as_expression() == {"alias": {"name": asset_alias.name, "group": asset_alias.group}} - def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1): - cond = AssetAliasCondition(asset_alias_1.name) - assert cond.evaluate({asset_1.uri: True}) is False + def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_model): + cond = AssetAliasCondition(name=asset_alias_1.name, group=asset_alias_1.group) + assert cond.evaluate({asset_model.uri: True}) is False - cond = AssetAliasCondition(resolved_asset_alias_2.name) - assert cond.evaluate({asset_1.uri: True}) is True + cond = AssetAliasCondition(name=resolved_asset_alias_2.name, group=resolved_asset_alias_2.group) + assert cond.evaluate({asset_model.uri: True}) is True class TestAssetSubclasses: From 4f7265b4933a70c0247c9a5b8e1bc0f3261468b5 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 10:09:18 +0800 Subject: [PATCH 23/51] test(api_connexion/endpoints/asset): extend asset test cases to cover name, uri, group --- tests/api_connexion/endpoints/test_asset_endpoint.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/api_connexion/endpoints/test_asset_endpoint.py b/tests/api_connexion/endpoints/test_asset_endpoint.py index db064ac5b443e..57bea9c6643e2 100644 --- a/tests/api_connexion/endpoints/test_asset_endpoint.py +++ b/tests/api_connexion/endpoints/test_asset_endpoint.py @@ -80,6 +80,8 @@ def _create_asset(self, session): asset_model = AssetModel( id=1, uri="s3://bucket/key", + name="asset-name", + group="asset", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), updated_at=timezone.parse(self.default_time), @@ -103,6 +105,8 @@ def test_should_respond_200(self, session): assert response.json == { "id": 1, "uri": "s3://bucket/key", + "name": "asset-name", + "group": "asset", "extra": {"foo": "bar"}, "created_at": self.default_time, "updated_at": self.default_time, @@ -136,6 +140,8 @@ def test_should_respond_200(self, session): AssetModel( id=i, uri=f"s3://bucket/key/{i}", + name=f"asset_{i}", + group="asset", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), updated_at=timezone.parse(self.default_time), @@ -156,6 +162,8 @@ def test_should_respond_200(self, session): { "id": 1, "uri": "s3://bucket/key/1", + "name": "asset_1", + "group": "asset", "extra": {"foo": "bar"}, "created_at": self.default_time, "updated_at": self.default_time, @@ -166,6 +174,8 @@ def test_should_respond_200(self, session): { "id": 2, "uri": "s3://bucket/key/2", + "name": "asset_2", + "group": "asset", "extra": {"foo": "bar"}, "created_at": self.default_time, "updated_at": self.default_time, From 3d1e510873ac05d136e80c0056d71d0b46220eab Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 10:09:29 +0800 Subject: [PATCH 24/51] test: add missing session --- tests/api_fastapi/core_api/routes/public/test_assets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/api_fastapi/core_api/routes/public/test_assets.py b/tests/api_fastapi/core_api/routes/public/test_assets.py index a20353d32f86d..9218cbbf820cd 100644 --- a/tests/api_fastapi/core_api/routes/public/test_assets.py +++ b/tests/api_fastapi/core_api/routes/public/test_assets.py @@ -722,7 +722,7 @@ def 
test_should_respond_200(self, test_client, session): } def test_invalid_attr_not_allowed(self, test_client, session): - self.create_assets() + self.create_assets(session) event_invalid_payload = {"asset_uri": "s3://bucket/key/1", "extra": {"foo": "bar"}, "fake": {}} response = test_client.post("/public/assets/events", json=event_invalid_payload) @@ -731,7 +731,7 @@ def test_invalid_attr_not_allowed(self, test_client, session): @pytest.mark.usefixtures("time_freezer") @pytest.mark.enable_redact def test_should_mask_sensitive_extra(self, test_client, session): - self.create_assets() + self.create_assets(session) event_payload = {"uri": "s3://bucket/key/1", "extra": {"password": "bar"}} response = test_client.post("/public/assets/events", json=event_payload) assert response.status_code == 200 From 3920a79a79bcb9a494e25c563144e1109881e691 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 15:21:01 +0800 Subject: [PATCH 25/51] test(www/views/asset): extend asset test cases to cover name, uri, group --- tests/www/views/test_views_asset.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/www/views/test_views_asset.py b/tests/www/views/test_views_asset.py index e4fda0aeac662..2e6668f134ad3 100644 --- a/tests/www/views/test_views_asset.py +++ b/tests/www/views/test_views_asset.py @@ -42,7 +42,10 @@ def _cleanup(self): @pytest.fixture def create_assets(self, session): def create(indexes): - assets = [AssetModel(id=i, uri=f"s3://bucket/key/{i}") for i in indexes] + assets = [ + AssetModel(id=i, uri=f"s3://bucket/key/{i}", name=f"asset-{i}", group="asset") + for i in indexes + ] session.add_all(assets) session.flush() session.add_all(AssetActive.for_asset(a) for a in assets) @@ -220,7 +223,7 @@ def test_search_uri_pattern(self, admin_client, create_assets, session): @pytest.mark.need_serialized_dag def test_correct_counts_update(self, admin_client, session, dag_maker, app, monkeypatch): with monkeypatch.context() as m: - assets = [Asset(uri=f"s3://bucket/key/{i}") for i in [1, 2, 3, 4, 5]] + assets = [Asset(uri=f"s3://bucket/key/{i}", name=f"asset-{i}") for i in range(1, 6)] # DAG that produces asset #1 with dag_maker(dag_id="upstream", schedule=None, serialized=True, session=session): @@ -399,7 +402,9 @@ def test_should_return_max_if_req_above(self, admin_client, create_assets, sessi class TestGetAssetNextRunSummary(TestAssetEndpoint): def test_next_run_asset_summary(self, dag_maker, admin_client): - with dag_maker(dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1")], serialized=True): + with dag_maker( + dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1", name="asset-1")], serialized=True + ): EmptyOperator(task_id="task1") response = admin_client.post("/next_run_assets_summary", data={"dag_ids": ["upstream"]}) From 42bb0d16a5611964a8329fbc998fbeed549d7e85 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 15:23:36 +0800 Subject: [PATCH 26/51] test(models/serialized_dag): extend asset test cases to cover name, uri, group --- tests/models/test_serialized_dag.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index 011e785626e2b..41632fe0458a3 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -243,16 +243,16 @@ def test_order_of_deps_is_consistent(self): dag_id="example", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), schedule=[ - Asset("1"), - Asset("2"), - Asset("3"), -
Asset("4"), - Asset("5"), + Asset(uri="test://asset1", name="1"), + Asset(uri="test://asset2", name="2"), + Asset(uri="test://asset3", name="3"), + Asset(uri="test://asset4", name="4"), + Asset(uri="test://asset5", name="5"), ], ) as dag6: BashOperator( task_id="any", - outlets=[Asset("0*"), Asset("6*")], + outlets=[Asset(uri="test://asset0", name="0*"), Asset(uri="test://asset6", name="6*")], bash_command="sleep 5", ) deps_order = [x["dependency_id"] for x in SerializedDAG.serialize_dag(dag6)["dag_dependencies"]] From 3a68576914d1b37f7c5cac45e047adc711949aa5 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 15:55:15 +0800 Subject: [PATCH 27/51] test(lineage/hook): extend asset test cases to cover name, uri, group --- tests/lineage/test_hook.py | 49 ++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py index ec6390c77a555..f66f6c2bf9f83 100644 --- a/tests/lineage/test_hook.py +++ b/tests/lineage/test_hook.py @@ -46,13 +46,26 @@ def test_are_assets_collected(self): assert self.collector.collected_assets == HookLineage() input_hook = BaseHook() output_hook = BaseHook() - self.collector.add_input_asset(input_hook, uri="s3://in_bucket/file") - self.collector.add_output_asset(output_hook, uri="postgres://example.com:5432/database/default/table") + self.collector.add_input_asset(input_hook, uri="s3://in_bucket/file", name="asset-1", group="test") + self.collector.add_output_asset( + output_hook, + uri="postgres://example.com:5432/database/default/table", + ) assert self.collector.collected_assets == HookLineage( - [AssetLineageInfo(asset=Asset("s3://in_bucket/file"), count=1, context=input_hook)], [ AssetLineageInfo( - asset=Asset("postgres://example.com:5432/database/default/table"), + asset=Asset(uri="s3://in_bucket/file", name="asset-1", group="test"), + count=1, + context=input_hook, + ) + ], + [ + AssetLineageInfo( + asset=Asset( + uri="postgres://example.com:5432/database/default/table", + name="postgres://example.com:5432/database/default/table", + group="asset", + ), count=1, context=output_hook, ) @@ -68,7 +81,7 @@ def test_add_input_asset(self, mock_asset): self.collector.add_input_asset(hook, uri="test_uri") assert next(iter(self.collector._inputs.values())) == (asset, hook) - mock_asset.assert_called_once_with(uri="test_uri", extra=None) + mock_asset.assert_called_once_with(uri="test_uri") def test_grouping_assets(self): hook_1 = MagicMock() @@ -95,18 +108,29 @@ def test_grouping_assets(self): @patch("airflow.lineage.hook.ProvidersManager") def test_create_asset(self, mock_providers_manager): def create_asset(arg1, arg2="default", extra=None): - return Asset(uri=f"myscheme://{arg1}/{arg2}", extra=extra or {}) + return Asset( + uri=f"myscheme://{arg1}/{arg2}", name=f"asset-{arg1}", group="test", extra=extra or {} + ) mock_providers_manager.return_value.asset_factories = {"myscheme": create_asset} assert self.collector.create_asset( - scheme="myscheme", uri=None, asset_kwargs={"arg1": "value_1"}, asset_extra=None - ) == Asset("myscheme://value_1/default") + scheme="myscheme", + uri=None, + name=None, + group=None, + asset_kwargs={"arg1": "value_1"}, + asset_extra=None, + ) == Asset(uri="myscheme://value_1/default", name="asset-value_1", group="test") assert self.collector.create_asset( scheme="myscheme", uri=None, + name=None, + group=None, asset_kwargs={"arg1": "value_1", "arg2": "value_2"}, asset_extra={"key": "value"}, - ) == Asset("myscheme://value_1/value_2", 
extra={"key": "value"}) + ) == Asset( + uri="myscheme://value_1/value_2", name="asset-value_1", group="test", extra={"key": "value"} + ) @patch("airflow.lineage.hook.ProvidersManager") def test_create_asset_no_factory(self, mock_providers_manager): @@ -117,7 +141,12 @@ def test_create_asset_no_factory(self, mock_providers_manager): assert ( self.collector.create_asset( - scheme=test_scheme, uri=None, asset_kwargs=test_kwargs, asset_extra=None + scheme=test_scheme, + uri=None, + name=None, + group=None, + asset_kwargs=test_kwargs, + asset_extra=None, ) is None ) From 4417d8f50e5236de2fd9dc994c4c69a885148fda Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 16:12:13 +0800 Subject: [PATCH 28/51] test(io/path): extend asset test cases to cover name, uri, group --- tests/io/test_path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/io/test_path.py b/tests/io/test_path.py index fd9844bc4bc58..29e67ca84649d 100644 --- a/tests/io/test_path.py +++ b/tests/io/test_path.py @@ -405,7 +405,7 @@ def test_asset(self): p = "s3" f = "bucket/object" - i = Asset(uri=f"{p}://{f}", extra={"foo": "bar"}) + i = Asset(uri=f"{p}://{f}", name="test-asset", extra={"foo": "bar"}) o = ObjectStoragePath(i) assert o.protocol == p assert o.path == f From 277813f814dbf20dab61cc48a462d1a34d6f4fab Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 17:51:51 +0800 Subject: [PATCH 29/51] test(jobs): enhance test_activate_referenced_assets_with_no_existing_warning to cover extra edge case --- tests/jobs/test_scheduler_job.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index ccbd32cca2741..6b413135bfc3d 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -6302,11 +6302,13 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session): asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1", extra=asset_extra) asset1_1 = Asset(name=asset1_name, uri="it's duplicate", extra=asset_extra) - dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1]) + asset1_2 = Asset(name="it's also a duplicate", uri="s3://bucket/key/1", extra=asset_extra) + dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1, asset1_2]) DAG.bulk_write_to_db([dag1], session=session) asset_models = session.scalars(select(AssetModel)).all() + assert len(asset_models) == 3 SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) session.flush() @@ -6317,8 +6319,10 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session): ) ) assert dag_warning.message == ( - "Cannot activate asset AssetModel(name='asset1', uri=\"it's duplicate\", extra={'foo': 'bar'}); " - "name is already associated to 's3://bucket/key/1'" + 'Cannot activate asset AssetModel(name="it\'s also a duplicate",' + " uri='s3://bucket/key/1', extra={'foo': 'bar'}); uri is already associated to 'asset1'\n" + "Cannot activate asset AssetModel(name='asset1', uri" + "=\"it's duplicate\", extra={'foo': 'bar'}); name is already associated to 's3://bucket/key/1'" ) def test_activate_referenced_assets_with_existing_warnings(self, session): From 5bfb68b6175262d1cc96cfb23525f26cf6feb9f8 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 Nov 2024 15:40:33 +0800 Subject: [PATCH 30/51] fix(serialization): serialize both name, uri and group for Asset --- airflow/serialization/serialized_objects.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 1a13430e2fcb5..8c5125eb8ef8e 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -256,7 +256,7 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]: :meta private: """ if isinstance(var, Asset): - return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri, "extra": var.extra} + return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri, "group": var.group, "extra": var.extra} if isinstance(var, AssetAlias): return {"__type": DAT.ASSET_ALIAS, "name": var.name} if isinstance(var, AssetAll): @@ -274,7 +274,7 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: """ dat = var["__type"] if dat == DAT.ASSET: - return Asset(uri=var["uri"], name=var["name"], extra=var["extra"]) + return Asset(name=var["name"], uri=var["uri"], group=var["group"], extra=var["extra"]) if dat == DAT.ASSET_ALL: return AssetAll(*(decode_asset_condition(x) for x in var["objects"])) if dat == DAT.ASSET_ANY: From f0aaf1d370d7db545194b91298721788fafc3abf Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 13 Nov 2024 18:36:12 +0800 Subject: [PATCH 31/51] fix(assets): extend Asset as_expression methods to include name, group fields (also AssetAlias group field) --- airflow/serialization/serialized_objects.py | 52 +++++++++++++++---- airflow/timetables/simple.py | 4 +- .../airflow/sdk/definitions/asset/__init__.py | 2 +- tests/www/views/test_views_grid.py | 7 ++- 4 files changed, 52 insertions(+), 13 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 8c5125eb8ef8e..d0b890f20009e 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -44,7 +44,11 @@ from airflow.models.connection import Connection from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun -from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key +from airflow.models.expandinput import ( + EXPAND_INPUT_EMPTY, + create_expand_input, + get_map_type_key, +) from airflow.models.mappedoperator import MappedOperator from airflow.models.param import Param, ParamsDict from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance @@ -213,7 +217,9 @@ def _get_registered_timetable(importable_string: str) -> type[Timetable] | None: return None -def _get_registered_priority_weight_strategy(importable_string: str) -> type[PriorityWeightStrategy] | None: +def _get_registered_priority_weight_strategy( + importable_string: str, +) -> type[PriorityWeightStrategy] | None: from airflow import plugins_manager if importable_string in airflow_priority_weight_strategies: @@ -256,13 +262,25 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]: :meta private: """ if isinstance(var, Asset): - return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri, "group": var.group, "extra": var.extra} + return { + "__type": DAT.ASSET, + "name": var.name, + "uri": var.uri, + "group": var.group, + "extra": var.extra, + } if isinstance(var, AssetAlias): return {"__type": DAT.ASSET_ALIAS, "name": var.name} if isinstance(var, AssetAll): - return {"__type": DAT.ASSET_ALL, "objects": [encode_asset_condition(x) for x in var.objects]} + return { + "__type": DAT.ASSET_ALL, + "objects": [encode_asset_condition(x) for x in var.objects], + } if isinstance(var, AssetAny): - 
return {"__type": DAT.ASSET_ANY, "objects": [encode_asset_condition(x) for x in var.objects]} + return { + "__type": DAT.ASSET_ANY, + "objects": [encode_asset_condition(x) for x in var.objects], + } raise ValueError(f"serialization not implemented for {type(var).__name__!r}") @@ -586,7 +604,9 @@ def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool: @classmethod def serialize_to_json( - cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set + cls, + object_to_serialize: BaseOperator | MappedOperator | DAG, + decorated_fields: set, ) -> dict[str, Any]: """Serialize an object to JSON.""" serialized_object: dict[str, Any] = {} @@ -696,7 +716,11 @@ def serialize( elif isinstance(var, (KeyError, AttributeError)): return cls._encode( cls.serialize( - {"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}}, + { + "exc_cls_name": var.__class__.__name__, + "args": [var.args], + "kwargs": {}, + }, use_pydantic_models=use_pydantic_models, strict=strict, ), @@ -704,7 +728,11 @@ def serialize( ) elif isinstance(var, BaseTrigger): return cls._encode( - cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict), + cls.serialize( + var.serialize(), + use_pydantic_models=use_pydantic_models, + strict=strict, + ), type_=DAT.BASE_TRIGGER, ) elif callable(var): @@ -1069,7 +1097,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]: ) ) elif isinstance(obj, AssetAlias): - cond = AssetAliasCondition(obj.name) + cond = AssetAliasCondition(name=obj.name, group=obj.group) deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target="")) return deps @@ -1298,7 +1326,11 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: # The case for "If OperatorLinks are defined in the operator that is being Serialized" # is handled in the deserialization loop where it matches k == "_operator_extra_links" if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op: - setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values())) + setattr( + op, + "operator_extra_links", + list(op_extra_links_from_plugin.values()), + ) for k, v in encoded_op.items(): # python_callable_name only serves to detect function name changes diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index f282c7fe67f8a..ac16eba983475 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -170,7 +170,9 @@ def __init__(self, assets: BaseAsset) -> None: super().__init__() self.asset_condition = assets if isinstance(self.asset_condition, AssetAlias): - self.asset_condition = AssetAliasCondition(self.asset_condition.name) + self.asset_condition = AssetAliasCondition( + name=self.asset_condition.name, group=self.asset_condition.group + ) if not next(self.asset_condition.iter_assets(), False): self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index da44a2c6b42ed..76fdeee9f30ad 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -350,7 +350,7 @@ def as_expression(self) -> Any: :meta private: """ - return self.uri + return {"asset": {"uri": self.uri, "name": self.name, "group": self.group}} def iter_assets(self) -> Iterator[tuple[str, Asset]]: yield self.uri, self diff --git a/tests/www/views/test_views_grid.py 
b/tests/www/views/test_views_grid.py index 3a34ca1b6dd89..e2181aa702d25 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -508,7 +508,12 @@ def test_next_run_assets(admin_client, dag_maker, session, app, monkeypatch): assert resp.status_code == 200, resp.json assert resp.json == { - "asset_expression": {"all": ["s3://bucket/key/1", "s3://bucket/key/2"]}, + "asset_expression": { + "all": [ + {"asset": {"uri": "s3://bucket/key/1", "name": "name_1", "group": "test-group"}}, + {"asset": {"uri": "s3://bucket/key/2", "name": "name_2", "group": "test-group"}}, + ] + }, "events": [ {"id": asset1_id, "uri": "s3://bucket/key/1", "lastUpdate": "2022-08-02T02:00:00+00:00"}, {"id": asset2_id, "uri": "s3://bucket/key/2", "lastUpdate": None}, From 8bb18624c5610cd2254821a007bbe2f302bfdce2 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 18:37:55 +0800 Subject: [PATCH 32/51] fix(serialization/serialized_objects): fix missing AssetAlias.group serialization --- airflow/serialization/serialized_objects.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d0b890f20009e..d93fcf183f17a 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -270,7 +270,7 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]: "extra": var.extra, } if isinstance(var, AssetAlias): - return {"__type": DAT.ASSET_ALIAS, "name": var.name} + return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group": var.group} if isinstance(var, AssetAll): return { "__type": DAT.ASSET_ALL, @@ -298,7 +298,7 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: if dat == DAT.ASSET_ANY: return AssetAny(*(decode_asset_condition(x) for x in var["objects"])) if dat == DAT.ASSET_ALIAS: - return AssetAlias(name=var["name"]) + return AssetAlias(name=var["name"], group=var["group"]) raise ValueError(f"deserialization not implemented for DAT {dat!r}") @@ -673,7 +673,11 @@ def serialize( return cls._encode(json_pod, type_=DAT.POD) elif isinstance(var, OutletEventAccessors): return cls._encode( - cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models), # type: ignore[attr-defined] + cls.serialize( + var._dict, # type: ignore[attr-defined] + strict=strict, + use_pydantic_models=use_pydantic_models, + ), type_=DAT.ASSET_EVENT_ACCESSORS, ) elif isinstance(var, OutletEventAccessor): From 0785ff2182fc0b37269a284fd323bb7caf3941d2 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 19:38:26 +0800 Subject: [PATCH 33/51] fix(serialization): change dependency_id to use name instead of uri --- airflow/serialization/serialized_objects.py | 2 +- task_sdk/src/airflow/sdk/definitions/asset/__init__.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d93fcf183f17a..f78a2b78b8811 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -1097,7 +1097,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]: source=task.dag_id, target="asset", dependency_type="asset", - dependency_id=obj.uri, + dependency_id=obj.name, ) ) elif isinstance(obj, AssetAlias): diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 
76fdeee9f30ad..6e57ffdb4e900 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -371,7 +371,7 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe source=source or "asset", target=target or "asset", dependency_type="asset", - dependency_id=self.uri, + dependency_id=self.name, ) @@ -544,18 +544,18 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat if self.objects: for obj in self.objects: asset = cast(Asset, obj) - uri = asset.uri + asset_name = asset.name # asset yield DagDependency( source=f"asset-alias:{self.name}" if source else "asset", target="asset" if source else f"asset-alias:{self.name}", dependency_type="asset", - dependency_id=uri, + dependency_id=asset_name, ) # asset alias yield DagDependency( - source=source or f"asset:{uri}", - target=target or f"asset:{uri}", + source=source or f"asset:{asset_name}", + target=target or f"asset:{asset_name}", dependency_type="asset-alias", dependency_id=self.name, ) From 22fb3f1a062be7ac95fec90dfb4d139f7d56c43b Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 21:29:08 +0800 Subject: [PATCH 34/51] feat(api_connexion/schemas/asset): add name, group to asset schema and group to asset alias schema --- airflow/api_connexion/schemas/asset_schema.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/api_connexion/schemas/asset_schema.py b/airflow/api_connexion/schemas/asset_schema.py index e83c4f1b42797..078ebb3e75866 100644 --- a/airflow/api_connexion/schemas/asset_schema.py +++ b/airflow/api_connexion/schemas/asset_schema.py @@ -70,6 +70,7 @@ class Meta: id = auto_field() name = auto_field() + group = auto_field() class AssetSchema(SQLAlchemySchema): @@ -82,6 +83,8 @@ class Meta: id = auto_field() uri = auto_field() + name = auto_field() + group = auto_field() extra = JsonObjectField() created_at = auto_field() updated_at = auto_field() From f020705cd82386b8304d7c84f7d93fa7e119799e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 22:03:18 +0800 Subject: [PATCH 35/51] feat(assets/manager): filter asset by name, uri, group instead of uri only --- airflow/assets/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/assets/manager.py b/airflow/assets/manager.py index 40bc97b8134c7..8ed2996933abc 100644 --- a/airflow/assets/manager.py +++ b/airflow/assets/manager.py @@ -121,7 +121,7 @@ def register_asset_change( """ asset_model = session.scalar( select(AssetModel) - .where(AssetModel.uri == asset.uri) + .where(AssetModel.name == asset.name, AssetModel.uri == asset.uri) .options( joinedload(AssetModel.aliases), joinedload(AssetModel.consuming_dags).joinedload(DagScheduleAssetReference.dag), From cc54512b35f6d006dc4305212112b4a72bd86261 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 22:11:51 +0800 Subject: [PATCH 36/51] style(assets/manager): rename argument asset in _add_asset_alias_association as asset_model --- airflow/assets/manager.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/airflow/assets/manager.py b/airflow/assets/manager.py index 8ed2996933abc..364d01607e5c4 100644 --- a/airflow/assets/manager.py +++ b/airflow/assets/manager.py @@ -86,16 +86,16 @@ def _add_one(asset_alias: AssetAlias) -> AssetAliasModel: def _add_asset_alias_association( cls, alias_names: Collection[str], - asset: AssetModel, + asset_model: AssetModel, *, session: Session, ) -> None: - already_related = {m.name for m in 
asset.aliases} + already_related = {m.name for m in asset_model.aliases} existing_aliases = { m.name: m for m in session.scalars(select(AssetAliasModel).where(AssetAliasModel.name.in_(alias_names))) } - asset.aliases.extend( + asset_model.aliases.extend( existing_aliases.get(name, AssetAliasModel(name=name)) for name in alias_names if name not in already_related ) @@ -131,7 +131,9 @@ def register_asset_change( cls.logger().warning("AssetModel %s not found", asset) return None - cls._add_asset_alias_association({alias.name for alias in aliases}, asset_model, session=session) + cls._add_asset_alias_association( + alias_names={alias.name for alias in aliases}, asset_model=asset_model, session=session + ) event_kwargs = { "asset_id": asset_model.id, From 6b0fd18222ee4c3045e8dff84b7e6390a09515e8 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 23:15:48 +0800 Subject: [PATCH 37/51] fix(asset): use name to evaluate instead of uri --- airflow/models/dag.py | 2 +- .../airflow/sdk/definitions/asset/__init__.py | 2 +- task_sdk/tests/defintions/test_asset.py | 32 +++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 34ed6694a0695..20fdc2d3d5190 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2297,7 +2297,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict) -> bool | None: del all_records dag_statuses = {} for dag_id, records in by_dag.items(): - dag_statuses[dag_id] = {x.asset.uri: True for x in records} + dag_statuses[dag_id] = {x.asset.name: True for x in records} ser_dags = SerializedDagModel.get_latest_serialized_dags( dag_ids=list(dag_statuses.keys()), session=session ) diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 6e57ffdb4e900..10e4d7db28a70 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -359,7 +359,7 @@ def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: return iter(()) def evaluate(self, statuses: dict[str, bool]) -> bool: - return statuses.get(self.uri, False) + return statuses.get(self.name, False) def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index 1a580061f65ec..39579483c2404 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -192,8 +192,8 @@ def test_asset_iter_asset_aliases(): def test_asset_evaluate(): - assert asset1.evaluate({"s3://bucket1/data1": True}) is True - assert asset1.evaluate({"s3://bucket1/data1": False}) is False + assert asset1.evaluate({"asset-1": True}) is True + assert asset1.evaluate({"asset-1": False}) is False def test_asset_any_operations(): @@ -219,8 +219,8 @@ def test_assset_boolean_condition_evaluate_iter(): """ any_condition = AssetAny(asset1, asset2) all_condition = AssetAll(asset1, asset2) - assert any_condition.evaluate({"s3://bucket1/data1": False, "s3://bucket2/data2": True}) is True - assert all_condition.evaluate({"s3://bucket1/data1": True, "s3://bucket2/data2": False}) is False + assert any_condition.evaluate({"asset-1": False, "asset-2": True}) is True + assert all_condition.evaluate({"asset-1": True, "asset-2": False}) is False # Testing iter_assets indirectly through the subclasses assets_any = dict(any_condition.iter_assets()) assets_all = dict(all_condition.iter_assets()) @@ -254,7 +254,7 @@ def
test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in range(123, 126)] condition = class_(*assets) - statuses = {asset.uri: status for asset, status in zip(assets, inputs)} + statuses = {asset.name: status for asset, status in zip(assets, inputs)} assert ( condition.evaluate(statuses) == expected ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" @@ -288,17 +288,17 @@ def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, ) def test_nested_asset_conditions_with_serialization(status_values, expected_evaluation): # Define assets - asset1 = Asset(uri="s3://abc/123") - asset2 = Asset(uri="s3://abc/124") - asset3 = Asset(uri="s3://abc/125") + asset1 = Asset(uri="s3://abc/123", name="asset-1") + asset2 = Asset(uri="s3://abc/124", name="asset-2") + asset3 = Asset(uri="s3://abc/125", name="asset-3") # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 and asset3 nested_condition = AssetAll(asset1, AssetAny(asset2, asset3)) statuses = { - asset1.uri: status_values[0], - asset2.uri: status_values[1], - asset3.uri: status_values[2], + asset1.name: status_values[0], + asset2.name: status_values[1], + asset3.name: status_values[2], } assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" @@ -363,7 +363,7 @@ def test_asset_dag_run_queue_processing(session, clear_assets, dag_maker, create records = session.scalars(select(AssetDagRunQueue)).all() dag_statuses = defaultdict(lambda: defaultdict(bool)) for record in records: - dag_statuses[record.target_dag_id][record.asset.uri] = True + dag_statuses[record.target_dag_id][record.asset.name] = True serialized_dags = session.execute( select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) ) for (serialized_dag,) in serialized_dags: dag = SerializedDAG.deserialize(serialized_dag.data) - for asset_uri, status in dag_statuses[dag.dag_id].items(): + for asset_name, status in dag_statuses[dag.dag_id].items(): cond = dag.timetable.asset_condition - assert cond.evaluate({asset_uri: status}), "DAG trigger evaluation failed" + assert cond.evaluate({asset_name: status}), "DAG trigger evaluation failed" @pytest.mark.db_test @@ -631,10 +631,10 @@ def test_as_expression(self, asset_alias_1, resolved_asset_alias_2): def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_model): cond = AssetAliasCondition(name=asset_alias_1.name, group=asset_alias_1.group) - assert cond.evaluate({asset_model.uri: True}) is False + assert cond.evaluate({asset_model.name: True}) is False cond = AssetAliasCondition(name=resolved_asset_alias_2.name, group=resolved_asset_alias_2.group) - assert cond.evaluate({asset_model.uri: True}) is True + assert cond.evaluate({asset_model.name: True}) is True class TestAssetSubclasses: From 4d3a61738649cc692419690b6b091a6011317509 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 10:10:10 +0800 Subject: [PATCH 38/51] fix(api_connexion/endpoints/asset): fix how the asset event is fetched in create asset event --- airflow/api_connexion/endpoints/asset_endpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/api_connexion/endpoints/asset_endpoint.py b/airflow/api_connexion/endpoints/asset_endpoint.py index 1bda4fdb2a218..64930b1249468 100644 ---
a/airflow/api_connexion/endpoints/asset_endpoint.py +++ b/airflow/api_connexion/endpoints/asset_endpoint.py @@ -45,7 +45,6 @@ ) from airflow.assets.manager import asset_manager from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel -from airflow.sdk.definitions.asset import Asset from airflow.utils import timezone from airflow.utils.api_migration import mark_fastapi_migration_done from airflow.utils.db import get_query_count @@ -341,15 +340,16 @@ def create_asset_event(session: Session = NEW_SESSION) -> APIResponse: except ValidationError as err: raise BadRequest(detail=str(err)) + # TODO: handle name uri = json_body["asset_uri"] - asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri).limit(1)) - if not asset: + asset_model = session.scalar(select(AssetModel).where(AssetModel.uri == uri).limit(1)) + if not asset_model: raise NotFound(title="Asset not found", detail=f"Asset with uri: '{uri}' not found") timestamp = timezone.utcnow() extra = json_body.get("extra", {}) extra["from_rest_api"] = True asset_event = asset_manager.register_asset_change( - asset=Asset(uri=uri), + asset=asset_model.to_public(), timestamp=timestamp, extra=extra, session=session, From 7c3a33df5065f39c75fb9e10700f85f79b37594f Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 12:08:29 +0800 Subject: [PATCH 39/51] fix(api_fastapi/asset): fix how the asset event is fetched in create asset event --- airflow/api_fastapi/core_api/routes/public/assets.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow/api_fastapi/core_api/routes/public/assets.py index 70bf5b047bbac..db3fa61767a9a 100644 --- a/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow/api_fastapi/core_api/routes/public/assets.py @@ -51,7 +51,6 @@ from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.assets.manager import asset_manager from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel -from airflow.sdk.definitions.asset import Asset from airflow.utils import timezone assets_router = AirflowRouter(tags=["Asset"]) @@ -171,13 +170,13 @@ def create_asset_event( session: SessionDep, ) -> AssetEventResponse: """Create asset events.""" - asset = session.scalar(select(AssetModel).where(AssetModel.uri == body.uri).limit(1)) - if not asset: + asset_model = session.scalar(select(AssetModel).where(AssetModel.uri == body.uri).limit(1)) + if not asset_model: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Asset with uri: `{body.uri}` was not found") timestamp = timezone.utcnow() assets_event = asset_manager.register_asset_change( - asset=Asset(uri=body.uri), + asset=asset_model.to_public(), timestamp=timestamp, extra=body.extra, session=session, From 3e16cea04bfb626d290753265c0251bcf81ae3e2 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 28 Nov 2024 15:55:55 +0800 Subject: [PATCH 40/51] fix(lineage/hook): extend asset related methods to include name and group --- airflow/lineage/hook.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py index 9e5f8f6648229..459d04b7dde8a 100644 --- a/airflow/lineage/hook.py +++ b/airflow/lineage/hook.py @@ -95,24 +95,39 @@ def _generate_key(self, asset: Asset, context: LineageContext) -> str: return f"{asset.uri}_{extra_hash}_{id(context)}" def create_asset( - self, scheme: str | None, uri: str | None, asset_kwargs: dict |
None, asset_extra: dict | None + self, + scheme: str | None = None, + uri: str | None = None, + name: str | None = None, + group: str | None = None, + asset_kwargs: dict | None = None, + asset_extra: dict | None = None, ) -> Asset | None: """ Create an asset instance using the provided parameters. This method attempts to create an asset instance using the given parameters. - It first checks if a URI is provided and falls back to using the default asset factory - with the given URI if no other information is available. + It first checks if a URI or a name is provided and falls back to using the default asset factory + with the given URI or name if no other information is available. - If a scheme is provided but no URI, it attempts to find an asset factory that matches + If a scheme is provided but no URI or name, it attempts to find an asset factory that matches the given scheme. If no such factory is found, it logs an error message and returns None. If asset_kwargs is provided, it is used to pass additional parameters to the asset factory. The asset_extra parameter is also passed to the factory as an ``extra`` parameter. """ - if uri: + if uri or name: # Fallback to default factory using the provided URI - return Asset(uri=uri, extra=asset_extra) + kwargs: dict[str, str | dict] = {} + if uri: + kwargs["uri"] = uri + if name: + kwargs["name"] = name + if group: + kwargs["group"] = group + if asset_extra: + kwargs["extra"] = asset_extra + return Asset(**kwargs) # type: ignore[call-overload] if not scheme: self.log.debug( @@ -137,11 +152,15 @@ def add_input_asset( context: LineageContext, scheme: str | None = None, uri: str | None = None, + name: str | None = None, + group: str | None = None, asset_kwargs: dict | None = None, asset_extra: dict | None = None, ): """Add the input asset and its corresponding hook execution context to the collector.""" - asset = self.create_asset(scheme=scheme, uri=uri, asset_kwargs=asset_kwargs, asset_extra=asset_extra) + asset = self.create_asset( + scheme=scheme, uri=uri, name=name, group=group, asset_kwargs=asset_kwargs, asset_extra=asset_extra + ) if asset: key = self._generate_key(asset, context) if key not in self._inputs: @@ -153,11 +172,15 @@ def add_output_asset( context: LineageContext, scheme: str | None = None, uri: str | None = None, + name: str | None = None, + group: str | None = None, asset_kwargs: dict | None = None, asset_extra: dict | None = None, ): """Add the output asset and its corresponding hook execution context to the collector.""" - asset = self.create_asset(scheme=scheme, uri=uri, asset_kwargs=asset_kwargs, asset_extra=asset_extra) + asset = self.create_asset( + scheme=scheme, uri=uri, name=name, group=group, asset_kwargs=asset_kwargs, asset_extra=asset_extra + ) if asset: key = self._generate_key(asset, context) if key not in self._outputs: From f290198ac39ee61fd8b31211e2fee9342825f9ed Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 27 Nov 2024 23:08:51 +0800 Subject: [PATCH 41/51] fix(task_sdk/asset): change iter_assets to return ((name, uri), obj) instead of (uri, obj) --- airflow/timetables/base.py | 4 ++-- .../airflow/sdk/definitions/asset/__init__.py | 16 +++++++++++----- task_sdk/tests/defintions/test_asset.py | 13 ++++++++++--- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py index 60b1c141209fe..b80a6323a8c18 100644 --- a/airflow/timetables/base.py +++ b/airflow/timetables/base.py @@ -19,7 +19,7 @@ from collections.abc import Iterator, Sequence from 
typing import TYPE_CHECKING, Any, NamedTuple -from airflow.sdk.definitions.asset import BaseAsset +from airflow.sdk.definitions.asset import AssetUniqueKey, BaseAsset from airflow.typing_compat import Protocol, runtime_checkable if TYPE_CHECKING: @@ -55,7 +55,7 @@ def as_expression(self) -> Any: def evaluate(self, statuses: dict[str, bool]) -> bool: return False - def iter_assets(self) -> Iterator[tuple[str, Asset]]: + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: return iter(()) def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 10e4d7db28a70..0c3418d503df2 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -27,6 +27,7 @@ Any, Callable, ClassVar, + NamedTuple, cast, overload, ) @@ -63,6 +64,11 @@ log = logging.getLogger(__name__) +class AssetUniqueKey(NamedTuple): + name: str + uri: str + + def normalize_noop(parts: SplitResult) -> SplitResult: """ Place-hold a :class:`~urllib.parse.SplitResult`` normalizer. @@ -203,7 +209,7 @@ def as_expression(self) -> Any: def evaluate(self, statuses: dict[str, bool]) -> bool: raise NotImplementedError - def iter_assets(self) -> Iterator[tuple[str, Asset]]: + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: raise NotImplementedError def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: @@ -352,8 +358,8 @@ def as_expression(self) -> Any: """ return {"asset": {"uri": self.uri, "name": self.name, "group": self.group}} - def iter_assets(self) -> Iterator[tuple[str, Asset]]: - yield self.uri, self + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: + yield AssetUniqueKey(name=self.name, uri=self.uri), self def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: return iter(()) @@ -401,7 +407,7 @@ class AssetAlias(BaseAsset): name: str = attrs.field(validator=_validate_non_empty_identifier) group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier) - def iter_assets(self) -> Iterator[tuple[str, Asset]]: + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: return iter(()) def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: @@ -446,7 +452,7 @@ def __init__(self, *objects: BaseAsset) -> None: def evaluate(self, statuses: dict[str, bool]) -> bool: return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects) - def iter_assets(self) -> Iterator[tuple[str, Asset]]: + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: seen = set() # We want to keep the first instance. 
for o in self.objects: for k, v in o.iter_assets(): diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index 39579483c2404..e79261cb6e472 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -170,7 +170,7 @@ def test_asset_logic_operations(): def test_asset_iter_assets(): - assert list(asset1.iter_assets()) == [("s3://bucket1/data1", asset1)] + assert list(asset1.iter_assets()) == [(("asset-1", "s3://bucket1/data1"), asset1)] @pytest.mark.db_test @@ -225,8 +225,15 @@ def test_assset_boolean_condition_evaluate_iter(): # Testing iter_assets indirectly through the subclasses assets_any = dict(any_condition.iter_assets()) assets_all = dict(all_condition.iter_assets()) - assert assets_any == {"s3://bucket1/data1": asset1, "s3://bucket2/data2": asset2} - assert assets_all == {"s3://bucket1/data1": asset1, "s3://bucket2/data2": asset2} + assert assets_any == { + ("asset-1", "s3://bucket1/data1"): asset1, + # AssetUniqueKey(name="asset-1", uri="s3://bucket1/data1"): asset1, + ("asset-2", "s3://bucket2/data2"): asset2, + } + assert assets_all == { + ("asset-1", "s3://bucket1/data1"): asset1, + ("asset-2", "s3://bucket2/data2"): asset2, + } @pytest.mark.parametrize( From 11c26f2d171322f7b32ff019f39079e9387bcbb3 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 29 Nov 2024 10:22:46 +0800 Subject: [PATCH 42/51] fix(fastapi/asset): add missing group column to asset alias schema --- airflow/api_fastapi/core_api/datamodels/assets.py | 1 + airflow/api_fastapi/core_api/openapi/v1-generated.yaml | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/airflow/api_fastapi/core_api/datamodels/assets.py b/airflow/api_fastapi/core_api/datamodels/assets.py index 638ee1cba6e29..72bba200fabf3 100644 --- a/airflow/api_fastapi/core_api/datamodels/assets.py +++ b/airflow/api_fastapi/core_api/datamodels/assets.py @@ -47,6 +47,7 @@ class AssetAliasSchema(BaseModel): id: int name: str + group: str class AssetResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 6c34fc4cf41f3..7e33084929043 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -5742,10 +5742,14 @@ components: name: type: string title: Name + group: + type: string + title: Group type: object required: - id - name + - group title: AssetAliasSchema description: Asset alias serializer for assets. 
AssetCollectionResponse: From 2dafc1ebff21f7b43557a032f865bea0ebae0dce Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 29 Nov 2024 15:21:57 +0800 Subject: [PATCH 43/51] build: build autogen ts files --- airflow/ui/openapi-gen/requests/schemas.gen.ts | 6 +++++- airflow/ui/openapi-gen/requests/types.gen.ts | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index c2cf77baab6ff..6e05c32f75302 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -99,9 +99,13 @@ export const $AssetAliasSchema = { type: "string", title: "Name", }, + group: { + type: "string", + title: "Group", + }, }, type: "object", - required: ["id", "name"], + required: ["id", "name", "group"], title: "AssetAliasSchema", description: "Asset alias serializer for assets.", } as const; diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index dcb3ec94f9526..545ef594f471e 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -27,6 +27,7 @@ export type AppBuilderViewResponse = { export type AssetAliasSchema = { id: number; name: string; + group: string; }; /** From f6d5e4aeefe7b2bb53d8722f95fa66fb9b058a35 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 29 Nov 2024 15:23:49 +0800 Subject: [PATCH 44/51] feat(lineage/hook): make create_asset keyword only --- airflow/lineage/hook.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py index 459d04b7dde8a..62a2c7a54933f 100644 --- a/airflow/lineage/hook.py +++ b/airflow/lineage/hook.py @@ -96,6 +96,7 @@ def _generate_key(self, asset: Asset, context: LineageContext) -> str: def create_asset( self, + *, scheme: str | None = None, uri: str | None = None, name: str | None = None, From f13f57120a31b3965fc641aa98fa77f0313f89e0 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 29 Nov 2024 15:28:01 +0800 Subject: [PATCH 45/51] docs(newsfragments): add 43774.significant.rst --- newsfragments/43774.significant.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 newsfragments/43774.significant.rst diff --git a/newsfragments/43774.significant.rst b/newsfragments/43774.significant.rst new file mode 100644 index 0000000000000..b716e1fc83f94 --- /dev/null +++ b/newsfragments/43774.significant.rst @@ -0,0 +1,22 @@ +``HookLineageCollector.create_asset`` now accept only keyword arguments + +To provider AIP-74 support, new arguments "name" and "group" are added to ``HookLineageCollector.create_asset``. +For easier change in the future, this function now takes only keyword arguments. + +.. Check the type of change that applies to this change + +* Types of change + + * [ ] DAG changes + * [ ] Config changes + * [ ] API changes + * [ ] CLI changes + * [x] Behaviour changes + * [ ] Plugin changes + * [ ] Dependency change + +.. 
List the migration rules needed for this change (see https://github.com/apache/airflow/issues/41641) + +* Migrations rules needed + + * Calling ``HookLineageCollector.create_asset`` with positional argument should raise an error From 3162024c17f212d9f3e84f3510de9e4f196aa8ca Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 29 Nov 2024 15:35:58 +0800 Subject: [PATCH 46/51] refactor(task_sdk/asset): add from_asset_alias to AssetAliasCondition to remove duplicate code --- airflow/timetables/simple.py | 4 +--- task_sdk/src/airflow/sdk/definitions/asset/__init__.py | 6 +++++- task_sdk/tests/defintions/test_asset.py | 10 +++++----- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index ac16eba983475..57eec884b558a 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -170,9 +170,7 @@ def __init__(self, assets: BaseAsset) -> None: super().__init__() self.asset_condition = assets if isinstance(self.asset_condition, AssetAlias): - self.asset_condition = AssetAliasCondition( - name=self.asset_condition.name, group=self.asset_condition.group - ) + self.asset_condition = AssetAliasCondition.from_asset_alias(self.asset_condition) if not next(self.asset_condition.iter_assets(), False): self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 0c3418d503df2..370ddce94303c 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -445,7 +445,7 @@ def __init__(self, *objects: BaseAsset) -> None: raise TypeError("expect asset expressions in condition") self.objects = [ - AssetAliasCondition(name=obj.name, group=obj.group) if isinstance(obj, AssetAlias) else obj + AssetAliasCondition.from_asset_alias(obj) if isinstance(obj, AssetAlias) else obj for obj in objects ] @@ -573,6 +573,10 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat dependency_id=self.name, ) + @staticmethod + def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasCondition: + return AssetAliasCondition(name=asset_alias.name, group=asset_alias.group) + class AssetAll(_AssetBooleanCondition): """Use to combine assets schedule references in an "or" relationship.""" diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index e79261cb6e472..b452939d21917 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -625,22 +625,22 @@ def resolved_asset_alias_2(self, session, asset_model): return asset_alias_2 def test_init(self, asset_alias_1, asset_model, resolved_asset_alias_2): - cond = AssetAliasCondition(name=asset_alias_1.name, group=asset_alias_1.group) + cond = AssetAliasCondition.from_asset_alias(asset_alias_1) assert cond.objects == [] - cond = AssetAliasCondition(name=resolved_asset_alias_2.name, group=resolved_asset_alias_2.group) + cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2) assert cond.objects == [Asset(uri=asset_model.uri, name=asset_model.name)] def test_as_expression(self, asset_alias_1, resolved_asset_alias_2): for asset_alias in (asset_alias_1, resolved_asset_alias_2): - cond = AssetAliasCondition(name=asset_alias.name, group=asset_alias.group) + cond = AssetAliasCondition.from_asset_alias(asset_alias) assert cond.as_expression() == {"alias": {"name": asset_alias.name, "group": 
asset_alias.group}} def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_model): - cond = AssetAliasCondition(name=asset_alias_1.name, group=asset_alias_1.group) + cond = AssetAliasCondition.from_asset_alias(asset_alias_1) assert cond.evaluate({asset_model.name: True}) is False - cond = AssetAliasCondition(name=resolved_asset_alias_2.name, group=resolved_asset_alias_2.group) + cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2) assert cond.evaluate({asset_model.name: True}) is True From e9c4e52ef0bd915dfdb786e9ab2af2f28c644d8c Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 29 Nov 2024 15:37:57 +0800 Subject: [PATCH 47/51] refactor(task_sdk/asset): add AssetUniqueKey.from_asset to reduce duplicate code --- task_sdk/src/airflow/sdk/definitions/asset/__init__.py | 5 ++++- task_sdk/tests/defintions/test_asset.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 370ddce94303c..f3bd80010b470 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -68,6 +68,9 @@ class AssetUniqueKey(NamedTuple): name: str uri: str + @staticmethod + def from_asset(asset: Asset) -> AssetUniqueKey: + return AssetUniqueKey(name=asset.name, uri=asset.uri) def normalize_noop(parts: SplitResult) -> SplitResult: """ @@ -359,7 +362,7 @@ def as_expression(self) -> Any: return {"asset": {"uri": self.uri, "name": self.name, "group": self.group}} def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: - yield AssetUniqueKey(name=self.name, uri=self.uri), self + yield AssetUniqueKey.from_asset(self), self def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: return iter(()) diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index b452939d21917..4277ee6cb93be 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -227,7 +227,6 @@ def test_assset_boolean_condition_evaluate_iter(): assets_all = dict(all_condition.iter_assets()) assert assets_any == { ("asset-1", "s3://bucket1/data1"): asset1, - # AssetUniqueKey(name="asset-1", uri="s3://bucket1/data1"): asset1, ("asset-2", "s3://bucket2/data2"): asset2, } assert assets_all == { From 981011e85ae6871973d0dda7dd1dccd713fd7494 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 29 Nov 2024 15:42:05 +0800 Subject: [PATCH 48/51] Revert "fix(asset): use name to evalute instead of uri" This reverts commit e812b8ada59e925beeb52c8ddb0d14b0dfec1abf. 
--- airflow/models/dag.py | 2 +- .../airflow/sdk/definitions/asset/__init__.py | 3 +- task_sdk/tests/defintions/test_asset.py | 32 +++++++++---------- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 20fdc2d3d5190..34ed6694a0695 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2297,7 +2297,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict) -> bool | None: del all_records dag_statuses = {} for dag_id, records in by_dag.items(): - dag_statuses[dag_id] = {x.asset.name: True for x in records} + dag_statuses[dag_id] = {x.asset.uri: True for x in records} ser_dags = SerializedDagModel.get_latest_serialized_dags( dag_ids=list(dag_statuses.keys()), session=session ) diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index f3bd80010b470..81af48a6b41b4 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -72,6 +72,7 @@ class AssetUniqueKey(NamedTuple): def from_asset(asset: Asset) -> AssetUniqueKey: return AssetUniqueKey(name=asset.name, uri=asset.uri) + def normalize_noop(parts: SplitResult) -> SplitResult: """ Place-hold a :class:`~urllib.parse.SplitResult`` normalizer. @@ -368,7 +369,7 @@ def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: return iter(()) def evaluate(self, statuses: dict[str, bool]) -> bool: - return statuses.get(self.name, False) + return statuses.get(self.uri, False) def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index 4277ee6cb93be..d9aa6305f579d 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -192,8 +192,8 @@ def test_asset_iter_asset_aliases(): def test_asset_evaluate(): - assert asset1.evaluate({"asset-1": True}) is True - assert asset1.evaluate({"asset-1": False}) is False + assert asset1.evaluate({"s3://bucket1/data1": True}) is True + assert asset1.evaluate({"s3://bucket1/data1": False}) is False def test_asset_any_operations(): @@ -219,8 +219,8 @@ def test_assset_boolean_condition_evaluate_iter(): """ any_condition = AssetAny(asset1, asset2) all_condition = AssetAll(asset1, asset2) - assert any_condition.evaluate({"asset-1": False, "asset-2": True}) is True - assert all_condition.evaluate({"asset-1": True, "asset-2": False}) is False + assert any_condition.evaluate({"s3://bucket1/data1": False, "s3://bucket2/data2": True}) is True + assert all_condition.evaluate({"s3://bucket1/data1": True, "s3://bucket2/data2": False}) is False # Testing iter_assets indirectly through the subclasses assets_any = dict(any_condition.iter_assets()) @@ -263,7 +263,7 @@ def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in range(123, 126)] condition = class_(*assets) - statuses = {asset.name: status for asset, status in zip(assets, inputs)} + statuses = {asset.uri: status for asset, status in zip(assets, inputs)} assert ( condition.evaluate(statuses) == expected ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" @@ -294,17 +294,17 @@ def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, ) def test_nested_asset_conditions_with_serialization(status_values, expected_evaluation): # Define assets 
- asset1 = Asset(uri="s3://abc/123", name="asset-1") - asset2 = Asset(uri="s3://abc/124", name="asset-2") - asset3 = Asset(uri="s3://abc/125", name="asset-3") + asset1 = Asset(uri="s3://abc/123") + asset2 = Asset(uri="s3://abc/124") + asset3 = Asset(uri="s3://abc/125") # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 and asset3 nested_condition = AssetAll(asset1, AssetAny(asset2, asset3)) statuses = { - asset1.name: status_values[0], - asset2.name: status_values[1], - asset3.name: status_values[2], + asset1.uri: status_values[0], + asset2.uri: status_values[1], + asset3.uri: status_values[2], } assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" @@ -369,7 +369,7 @@ def test_asset_dag_run_queue_processing(session, clear_assets, dag_maker, create records = session.scalars(select(AssetDagRunQueue)).all() dag_statuses = defaultdict(lambda: defaultdict(bool)) for record in records: - dag_statuses[record.target_dag_id][record.asset.name] = True + dag_statuses[record.target_dag_id][record.asset.uri] = True serialized_dags = session.execute( select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) @@ -377,9 +377,9 @@ def test_asset_dag_run_queue_processing(session, clear_assets, dag_maker, create for (serialized_dag,) in serialized_dags: dag = SerializedDAG.deserialize(serialized_dag.data) - for asset_name, status in dag_statuses[dag.dag_id].items(): + for asset_uri, status in dag_statuses[dag.dag_id].items(): cond = dag.timetable.asset_condition - assert cond.evaluate({asset_name: status}), "DAG trigger evaluation failed" + assert cond.evaluate({asset_uri: status}), "DAG trigger evaluation failed" @pytest.mark.db_test @@ -637,10 +637,10 @@ def test_as_expression(self, asset_alias_1, resolved_asset_alias_2): def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_model): cond = AssetAliasCondition.from_asset_alias(asset_alias_1) - assert cond.evaluate({asset_model.name: True}) is False + assert cond.evaluate({asset_model.uri: True}) is False cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2) - assert cond.evaluate({asset_model.name: True}) is True + assert cond.evaluate({asset_model.uri: True}) is True class TestAssetSubclasses: From 7790bfa90e83badb7327bf051b7c086322d49746 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 13 Nov 2024 18:51:38 +0800 Subject: [PATCH 49/51] docs: removing examples that accessing inlet and outlet events --- airflow/example_dags/example_asset_alias.py | 2 +- .../example_asset_alias_with_no_taskflow.py | 2 +- .../authoring-and-scheduling/datasets.rst | 12 +++++------ tests/models/test_taskinstance.py | 20 +++++++++---------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/airflow/example_dags/example_asset_alias.py b/airflow/example_dags/example_asset_alias.py index a7f2aac5845c2..2d26fa32101c2 100644 --- a/airflow/example_dags/example_asset_alias.py +++ b/airflow/example_dags/example_asset_alias.py @@ -67,7 +67,7 @@ def produce_asset_events(): def produce_asset_events_through_asset_alias(*, outlet_events=None): bucket_name = "bucket" object_path = "my-task" - outlet_events["example-alias"].add(Asset(f"s3://{bucket_name}/{object_path}")) + outlet_events[AssetAlias("example-alias")].add(Asset(f"s3://{bucket_name}/{object_path}")) produce_asset_events_through_asset_alias() diff --git a/airflow/example_dags/example_asset_alias_with_no_taskflow.py b/airflow/example_dags/example_asset_alias_with_no_taskflow.py index 
19f31465ea4f8..c3d1ac0b8d14d 100644 --- a/airflow/example_dags/example_asset_alias_with_no_taskflow.py +++ b/airflow/example_dags/example_asset_alias_with_no_taskflow.py @@ -68,7 +68,7 @@ def produce_asset_events(): def produce_asset_events_through_asset_alias_with_no_taskflow(*, outlet_events=None): bucket_name = "bucket" object_path = "my-task" - outlet_events["example-alias-no-taskflow"].add(Asset(f"s3://{bucket_name}/{object_path}")) + outlet_events[AssetAlias("example-alias-no-taskflow")].add(Asset(f"s3://{bucket_name}/{object_path}")) PythonOperator( task_id="produce_asset_events_through_asset_alias_with_no_taskflow", diff --git a/docs/apache-airflow/authoring-and-scheduling/datasets.rst b/docs/apache-airflow/authoring-and-scheduling/datasets.rst index 9e777d9299587..51a6c5be30a21 100644 --- a/docs/apache-airflow/authoring-and-scheduling/datasets.rst +++ b/docs/apache-airflow/authoring-and-scheduling/datasets.rst @@ -445,7 +445,7 @@ The following example creates an asset event against the S3 URI ``f"s3://bucket/ @task(outlets=[AssetAlias("my-task-outputs")]) def my_task_with_outlet_events(*, outlet_events): - outlet_events["my-task-outputs"].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) + outlet_events[AssetAlias("my-task-outputs")].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) **Emit an asset event during task execution through yielding Metadata** @@ -475,11 +475,11 @@ Only one asset event is emitted for an added asset, even if it is added to the a ] ) def my_task_with_outlet_events(*, outlet_events): - outlet_events["my-task-outputs-1"].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) + outlet_events[AssetAlias("my-task-outputs-1")].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) # This line won't emit an additional asset event as the asset and extra are the same as the previous line. - outlet_events["my-task-outputs-2"].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) + outlet_events[AssetAlias("my-task-outputs-2")].add(Asset("s3://bucket/my-task"), extra={"k": "v"}) # This line will emit an additional asset event as the extra is different. - outlet_events["my-task-outputs-3"].add(Asset("s3://bucket/my-task"), extra={"k2": "v2"}) + outlet_events[AssetAlias("my-task-outputs-3")].add(Asset("s3://bucket/my-task"), extra={"k2": "v2"}) Scheduling based on asset aliases ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -500,7 +500,7 @@ The asset alias is resolved to the assets during DAG parsing. 
Thus, if the "min_ @task(outlets=[AssetAlias("example-alias")]) def produce_asset_events(*, outlet_events): - outlet_events["example-alias"].add(Asset("s3://bucket/my-task")) + outlet_events[AssetAlias("example-alias")].add(Asset("s3://bucket/my-task")) with DAG(dag_id="asset-consumer", schedule=Asset("s3://bucket/my-task")): @@ -524,7 +524,7 @@ As mentioned in :ref:`Fetching information from previously emitted asset events< @task(outlets=[AssetAlias("example-alias")]) def produce_asset_events(*, outlet_events): - outlet_events["example-alias"].add(Asset("s3://bucket/my-task"), extra={"row_count": 1}) + outlet_events[AssetAlias("example-alias")].add(Asset("s3://bucket/my-task"), extra={"row_count": 1}) with DAG(dag_id="asset-alias-consumer", schedule=None): diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 51803e788cbbc..e6563c4417494 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2405,12 +2405,12 @@ def test_outlet_asset_extra(self, dag_maker, session): @task(outlets=Asset("test_outlet_asset_extra_1")) def write1(*, outlet_events): - outlet_events["test_outlet_asset_extra_1"].extra = {"foo": "bar"} + outlet_events[Asset("test_outlet_asset_extra_1")].extra = {"foo": "bar"} write1() def _write2_post_execute(context, _): - context["outlet_events"]["test_outlet_asset_extra_2"].extra = {"x": 1} + context["outlet_events"][Asset("test_outlet_asset_extra_2")].extra = {"x": 1} BashOperator( task_id="write2", @@ -2446,8 +2446,8 @@ def test_outlet_asset_extra_ignore_different(self, dag_maker, session): @task(outlets=Asset("test_outlet_asset_extra")) def write(*, outlet_events): - outlet_events["test_outlet_asset_extra"].extra = {"one": 1} - outlet_events["different_uri"].extra = {"foo": "bar"} # Will be silently dropped. + outlet_events[Asset("test_outlet_asset_extra")].extra = {"one": 1} + outlet_events[Asset("different_uri")].extra = {"foo": "bar"} # Will be silently dropped. write() @@ -2722,22 +2722,22 @@ def test_inlet_asset_extra(self, dag_maker, session): @task(outlets=Asset("test_inlet_asset_extra")) def write(*, ti, outlet_events): - outlet_events["test_inlet_asset_extra"].extra = {"from": ti.task_id} + outlet_events[Asset("test_inlet_asset_extra")].extra = {"from": ti.task_id} @task(inlets=Asset("test_inlet_asset_extra")) def read(*, inlet_events): - second_event = inlet_events["test_inlet_asset_extra"][1] + second_event = inlet_events[Asset("test_inlet_asset_extra")][1] assert second_event.uri == "test_inlet_asset_extra" assert second_event.extra == {"from": "write2"} - last_event = inlet_events["test_inlet_asset_extra"][-1] + last_event = inlet_events[Asset("test_inlet_asset_extra")][-1] assert last_event.uri == "test_inlet_asset_extra" assert last_event.extra == {"from": "write3"} with pytest.raises(KeyError): - inlet_events["does_not_exist"] + inlet_events[Asset("does_not_exist")] with pytest.raises(IndexError): - inlet_events["test_inlet_asset_extra"][5] + inlet_events[Asset("test_inlet_asset_extra")][5] # TODO: Support slices. 
@@ -2797,7 +2797,7 @@ def read(*, inlet_events): assert last_event.extra == {"from": "write3"} with pytest.raises(KeyError): - inlet_events["does_not_exist"] + inlet_events[Asset("does_not_exist")] with pytest.raises(KeyError): inlet_events[AssetAlias("does_not_exist")] with pytest.raises(IndexError): From 2f2c3a78da7dcfb7d7342cdc69547251c51be228 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 13 Nov 2024 19:30:48 +0800 Subject: [PATCH 50/51] feat(utils/context): Deprecate accessing inlet and outlet events through string --- airflow/serialization/serialized_objects.py | 6 +- airflow/utils/context.py | 60 +++++++++++-------- airflow/utils/context.pyi | 18 +++--- airflow/utils/operator_helpers.py | 9 ++- .../serialization/test_serialized_objects.py | 11 +--- tests/utils/test_context.py | 14 ++--- 6 files changed, 64 insertions(+), 54 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index f78a2b78b8811..f25f7a9a6af8a 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -303,11 +303,11 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]: - raw_key = var.raw_key + key = var.key return { "extra": var.extra, "asset_alias_events": var.asset_alias_events, - "raw_key": BaseSerialization.serialize(raw_key), + "key": BaseSerialization.serialize(key), } @@ -316,7 +316,7 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor: outlet_event_accessor = OutletEventAccessor( extra=var["extra"], - raw_key=BaseSerialization.deserialize(var["raw_key"]), + key=BaseSerialization.deserialize(var["key"]), asset_alias_events=asset_alias_events, ) return outlet_event_accessor diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 90232e7b2efd8..82557fe16ad4c 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -23,7 +23,15 @@ import copy import functools import warnings -from collections.abc import Container, ItemsView, Iterator, KeysView, Mapping, MutableMapping, ValuesView +from collections.abc import ( + Container, + ItemsView, + Iterator, + KeysView, + Mapping, + MutableMapping, + ValuesView, +) from typing import ( TYPE_CHECKING, Any, @@ -35,12 +43,18 @@ from sqlalchemy import select from airflow.exceptions import RemovedInAirflow3Warning -from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, _fetch_active_assets_by_name +from airflow.models.asset import ( + AssetAliasModel, + AssetEvent, + AssetModel, + _fetch_active_assets_by_name, +) from airflow.sdk.definitions.asset import ( Asset, AssetAlias, AssetAliasEvent, AssetRef, + BaseAsset, ) from airflow.sdk.definitions.asset.metadata import extract_event_key from airflow.utils.db import LazySelectSequence @@ -153,33 +167,30 @@ class OutletEventAccessor: :meta private: """ - raw_key: str | Asset | AssetAlias + key: BaseAsset extra: dict[str, Any] = attrs.Factory(dict) asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) - def add(self, asset: Asset | str, extra: dict[str, Any] | None = None) -> None: + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: """Add an AssetEvent to an existing Asset.""" - if isinstance(asset, str): - asset_uri = asset - elif isinstance(asset, Asset): - asset_uri = asset.uri - else: + if not isinstance(asset, Asset): return - if isinstance(self.raw_key, str): - asset_alias_name = self.raw_key - elif 
isinstance(self.raw_key, AssetAlias): - asset_alias_name = self.raw_key.name + if isinstance(self.key, AssetAlias): + asset_alias_name = self.key.name else: return + # TODO: handle asset.name event = AssetAliasEvent( - source_alias_name=asset_alias_name, dest_asset_uri=asset_uri, extra=extra or {} + source_alias_name=asset_alias_name, + dest_asset_uri=asset.uri, + extra=extra or {}, ) self.asset_alias_events.append(event) -class OutletEventAccessors(Mapping[str, OutletEventAccessor]): +class OutletEventAccessors(Mapping[BaseAsset, OutletEventAccessor]): """ Lazy mapping of outlet asset event accessors. @@ -187,22 +198,21 @@ class OutletEventAccessors(Mapping[str, OutletEventAccessor]): """ def __init__(self) -> None: - self._dict: dict[str, OutletEventAccessor] = {} + self._dict: dict[BaseAsset, OutletEventAccessor] = {} def __str__(self) -> str: return f"OutletEventAccessors(_dict={self._dict})" - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[BaseAsset]: return iter(self._dict) def __len__(self) -> int: return len(self._dict) - def __getitem__(self, key: str | Asset | AssetAlias) -> OutletEventAccessor: - event_key = extract_event_key(key) - if event_key not in self._dict: - self._dict[event_key] = OutletEventAccessor(extra={}, raw_key=key) - return self._dict[event_key] + def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: + if key not in self._dict: + self._dict[key] = OutletEventAccessor(extra={}, key=key) + return self._dict[key] class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]): @@ -222,7 +232,7 @@ def _process_row(row: Row) -> AssetEvent: @attrs.define(init=False) -class InletEventsAccessors(Mapping[str, LazyAssetEventSelectSequence]): +class InletEventsAccessors(Mapping[BaseAsset, LazyAssetEventSelectSequence]): """ Lazy mapping for inlet asset events accessors. @@ -253,13 +263,13 @@ def __init__(self, inlets: list, *, session: Session) -> None: for asset_name, asset in _fetch_active_assets_by_name(_asset_ref_names, self._session).items(): self._assets[asset_name] = asset - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[BaseAsset]: return iter(self._inlets) def __len__(self) -> int: return len(self._inlets) - def __getitem__(self, key: int | str | Asset | AssetAlias) -> LazyAssetEventSelectSequence: + def __getitem__(self, key: int | BaseAsset) -> LazyAssetEventSelectSequence: if isinstance(key, int): # Support index access; it's easier for trivial cases. obj = self._inlets[key] if not isinstance(obj, (Asset, AssetAlias, AssetRef)): diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index d0ea2132f1c95..e25a0c7681873 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -63,18 +63,18 @@ class OutletEventAccessor: self, *, extra: dict[str, Any], - raw_key: str | Asset | AssetAlias, + key: Asset | AssetAlias, asset_alias_events: list[AssetAliasEvent], ) -> None: ... - def add(self, asset: Asset | str, extra: dict[str, Any] | None = None) -> None: ... + def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ... extra: dict[str, Any] - raw_key: str | Asset | AssetAlias + key: Asset | AssetAlias asset_alias_events: list[AssetAliasEvent] -class OutletEventAccessors(Mapping[str, OutletEventAccessor]): - def __iter__(self) -> Iterator[str]: ... +class OutletEventAccessors(Mapping[Asset | AssetAlias, OutletEventAccessor]): + def __iter__(self) -> Iterator[Asset | AssetAlias]: ... def __len__(self) -> int: ... 
- def __getitem__(self, key: str | Asset | AssetAlias) -> OutletEventAccessor: ... + def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: ... class InletEventsAccessor(Sequence[AssetEvent]): @overload @@ -83,11 +83,11 @@ class InletEventsAccessor(Sequence[AssetEvent]): def __getitem__(self, key: slice) -> Sequence[AssetEvent]: ... def __len__(self) -> int: ... -class InletEventsAccessors(Mapping[str, InletEventsAccessor]): +class InletEventsAccessors(Mapping[Asset | AssetAlias, InletEventsAccessor]): def __init__(self, inlets: list, *, session: Session) -> None: ... - def __iter__(self) -> Iterator[str]: ... + def __iter__(self) -> Iterator[Asset | AssetAlias]: ... def __len__(self) -> int: ... - def __getitem__(self, key: int | str | Asset | AssetAlias) -> InletEventsAccessor: ... + def __getitem__(self, key: int | Asset | AssetAlias) -> InletEventsAccessor: ... # NOTE: Please keep this in sync with the following: # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index e5e304bb4a414..3dd81630ca071 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar from airflow import settings +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.sdk.definitions.asset.metadata import Metadata from airflow.typing_compat import ParamSpec from airflow.utils.context import Context, lazy_mapping_from_context @@ -276,10 +277,14 @@ def _run(): for metadata in _run(): if isinstance(metadata, Metadata): - outlet_events[metadata.uri].extra.update(metadata.extra) + # TODO: handle asset name + outlet_events[Asset(uri=metadata.uri)].extra.update(metadata.extra) if metadata.alias_name: - outlet_events[metadata.alias_name].add(metadata.uri, extra=metadata.extra) + # TODO: handle asset name + outlet_events[AssetAlias(name=metadata.alias_name)].add( + Asset(uri=metadata.uri), extra=metadata.extra + ) continue logger.warning("Ignoring unknown data of %r received from task", type(metadata)) diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 3e8e844528822..3dd7d3211bdad 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -163,7 +163,7 @@ def equal_exception(a: AirflowException, b: AirflowException) -> bool: def equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor) -> bool: - return a.raw_key == b.raw_key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events + return a.key == b.key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events class MockLazySelectSequence(LazySelectSequence): @@ -258,7 +258,7 @@ def __len__(self) -> int: ), ( OutletEventAccessor( - raw_key=Asset(uri="test://asset1", name="test", group="test-group"), + key=Asset(uri="test", name="test", group="test-group"), extra={"key": "value"}, asset_alias_events=[], ), @@ -267,7 +267,7 @@ def __len__(self) -> int: ), ( OutletEventAccessor( - raw_key=AssetAlias(name="test_alias", group="test-alias-group"), + key=AssetAlias(name="test_alias", group="test-alias-group"), extra={"key": "value"}, asset_alias_events=[ AssetAliasEvent( @@ -280,11 +280,6 @@ def __len__(self) -> int: DAT.ASSET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), - ( - OutletEventAccessor(raw_key="test", extra={"key": "value"}, asset_alias_events=[]), - 
DAT.ASSET_EVENT_ACCESSOR, - equal_outlet_event_accessor, - ), ( AirflowException("test123 wohoo!"), DAT.AIRFLOW_EXC_SER, diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py index 0e7309075b38c..3cd2972a9b749 100644 --- a/tests/utils/test_context.py +++ b/tests/utils/test_context.py @@ -27,7 +27,7 @@ class TestOutletEventAccessor: @pytest.mark.parametrize( - "raw_key, asset_alias_events", + "key, asset_alias_events", ( ( AssetAlias("test_alias"), @@ -36,14 +36,14 @@ class TestOutletEventAccessor: (Asset("test_uri"), []), ), ) - def test_add(self, raw_key, asset_alias_events): - outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={}) + def test_add(self, key, asset_alias_events): + outlet_event_accessor = OutletEventAccessor(key=key, extra={}) outlet_event_accessor.add(Asset("test_uri")) assert outlet_event_accessor.asset_alias_events == asset_alias_events @pytest.mark.db_test @pytest.mark.parametrize( - "raw_key, asset_alias_events", + "key, asset_alias_events", ( ( AssetAlias("test_alias"), @@ -56,13 +56,13 @@ def test_add(self, raw_key, asset_alias_events): (Asset("test_uri"), []), ), ) - def test_add_with_db(self, raw_key, asset_alias_events, session): + def test_add_with_db(self, key, asset_alias_events, session): asm = AssetModel(uri="test_uri") aam = AssetAliasModel(name="test_alias") session.add_all([asm, aam]) session.flush() - outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={"not": ""}) + outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) outlet_event_accessor.add("test_uri", extra={}) assert outlet_event_accessor.asset_alias_events == asset_alias_events @@ -74,5 +74,5 @@ def test____get_item___dict_key_not_exists(self, key): assert len(outlet_event_accessors) == 0 outlet_event_accessor = outlet_event_accessors[key] assert len(outlet_event_accessors) == 1 - assert outlet_event_accessor.raw_key == key + assert outlet_event_accessor.key == key assert outlet_event_accessor.extra == {} From b0efef5ff3279733cd5450e0045c54aeef116b8b Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 29 Nov 2024 13:07:33 +0800 Subject: [PATCH 51/51] feat(tmp): commit --- airflow/utils/context.py | 21 ++++++++++++++----- .../airflow/sdk/definitions/asset/__init__.py | 8 +++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 82557fe16ad4c..166ed9b882e09 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -53,7 +53,9 @@ Asset, AssetAlias, AssetAliasEvent, + AssetAliasUniqueKey, AssetRef, + AssetUniqueKey, BaseAsset, ) from airflow.sdk.definitions.asset.metadata import extract_event_key @@ -167,7 +169,7 @@ class OutletEventAccessor: :meta private: """ - key: BaseAsset + key: AssetUniqueKey | AssetAliasUniqueKey extra: dict[str, Any] = attrs.Factory(dict) asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) @@ -176,7 +178,7 @@ def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: if not isinstance(asset, Asset): return - if isinstance(self.key, AssetAlias): + if isinstance(self.key, AssetAliasUniqueKey): asset_alias_name = self.key.name else: return @@ -210,9 +212,18 @@ def __len__(self) -> int: return len(self._dict) def __getitem__(self, key: BaseAsset) -> OutletEventAccessor: - if key not in self._dict: - self._dict[key] = OutletEventAccessor(extra={}, key=key) - return self._dict[key] + hashable_key: AssetUniqueKey | AssetAliasUniqueKey + if isinstance(key, Asset): + hashable_key = 
AssetUniqueKey.from_asset(key) + elif isinstance(key, AssetAlias): + hashable_key = AssetAliasUniqueKey.from_asset_alias(key) + else: + # TODO + raise SystemExit() + + if hashable_key not in self._dict: + self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key) + return self._dict[hashable_key] class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]): diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 81af48a6b41b4..98c12ca6fd52a 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -73,6 +73,14 @@ def from_asset(asset: Asset) -> AssetUniqueKey: return AssetUniqueKey(name=asset.name, uri=asset.uri) +class AssetAliasUniqueKey(NamedTuple): + name: str + + @staticmethod + def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasUniqueKey: + return AssetAliasUniqueKey(name=asset_alias.name) + + def normalize_noop(parts: SplitResult) -> SplitResult: """ Place-hold a :class:`~urllib.parse.SplitResult`` normalizer.