From 7b19c701ddbc5ea30c9989306cdc18a026e2bccb Mon Sep 17 00:00:00 2001
From: Tzu-ping Chung
Date: Tue, 13 Aug 2024 18:31:22 +0800
Subject: [PATCH 1/7] Add 'name' to Dataset

The Dataset class now accepts *either* name or uri, and fills in the other
field from the provided value when only one of them is explicitly set.

Although not strictly required by this change, 'name' on DatasetAlias now
also receives the same parse-time checks as Dataset's name and uri fields,
so they emit the same errors on incorrect input.
---
 airflow/datasets/__init__.py | 49 +-
 ...5_3_0_0_add_name_field_to_dataset_model.py | 58 +
 airflow/models/dataset.py | 12 +
 airflow/serialization/serialized_objects.py | 5 +-
 airflow/utils/db.py | 2 +-
 docs/apache-airflow/img/airflow_erd.sha256 | 2 +-
 docs/apache-airflow/img/airflow_erd.svg | 2297 +++++++++--------
 docs/apache-airflow/migrations-ref.rst | 4 +-
 8 files changed, 1267 insertions(+), 1162 deletions(-)
 create mode 100644 airflow/migrations/versions/0005_3_0_0_add_name_field_to_dataset_model.py

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 55d947544c1d2..2fa1c972853b2 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -63,6 +63,28 @@ def _get_normalized_scheme(uri: str) -> str:
     return parsed.scheme.lower()
 
 
+class _IdentifierValidator:
+    @staticmethod
+    def validate(kind: str, value: str | None, *, optional: bool = False) -> None:
+        if optional and value is None:
+            return
+        if not value:
+            raise ValueError(f"{kind} cannot be empty")
+        if len(value) > 3000:
+            raise ValueError(f"{kind} must be at most 3000 characters")
+        if value.isspace():
+            raise ValueError(f"{kind} cannot be just whitespace")
+        if not value.isascii():
+            raise ValueError(f"{kind} must only consist of ASCII characters")
+
+    def __call__(self, inst: Dataset | DatasetAlias, attribute: attr.Attribute, value: str | None) -> None:
+        self.validate(
+            f"{type(inst).__name__} {attribute.name}",
+            value,
+            optional=attribute.default is None,
+        )
+
+
 def _sanitize_uri(uri: str) -> str:
     """
     Sanitize a dataset URI.
@@ -70,12 +92,6 @@ def _sanitize_uri(uri: str) -> str:
     This checks for URI validity, and normalizes the URI if needed. A fully
     normalized URI is returned.
     """
-    if not uri:
-        raise ValueError("Dataset URI cannot be empty")
-    if uri.isspace():
-        raise ValueError("Dataset URI cannot be just whitespace")
-    if not uri.isascii():
-        raise ValueError("Dataset URI must only consist of ASCII characters")
     parsed = urllib.parse.urlsplit(uri)
     if not parsed.scheme and not parsed.netloc:  # Does not look like a URI.
return uri @@ -133,10 +149,10 @@ def extract_event_key(value: str | Dataset | DatasetAlias) -> str: """ if isinstance(value, DatasetAlias): return value.name - if isinstance(value, Dataset): return value.uri - return _sanitize_uri(str(value)) + _IdentifierValidator.validate("Dataset event key", uri := str(value)) + return _sanitize_uri(uri) @internal_api_call @@ -210,7 +226,7 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe class DatasetAlias(BaseDataset): """A represeation of dataset alias which is used to create dataset during the runtime.""" - name: str + name: str = attr.field(validator=_IdentifierValidator()) def __eq__(self, other: Any) -> bool: if isinstance(other, DatasetAlias): @@ -245,14 +261,25 @@ class DatasetAliasEvent(TypedDict): class Dataset(os.PathLike, BaseDataset): """A representation of data dependencies between workflows.""" + name: str = attr.field(default=None, validator=_IdentifierValidator()) uri: str = attr.field( + default=None, + kw_only=True, converter=_sanitize_uri, - validator=[attr.validators.min_len(1), attr.validators.max_len(3000)], + validator=_IdentifierValidator(), ) - extra: dict[str, Any] | None = None + extra: dict[str, Any] | None = attr.field(kw_only=True, default=None) __version__: ClassVar[int] = 1 + def __attrs_post_init__(self) -> None: + if self.name is None and self.uri is None: + raise TypeError("Dataset requires either name or URI") + if self.name is None: + self.name = self.uri + elif self.uri is None: + self.uri = self.name + def __fspath__(self) -> str: return self.uri diff --git a/airflow/migrations/versions/0005_3_0_0_add_name_field_to_dataset_model.py b/airflow/migrations/versions/0005_3_0_0_add_name_field_to_dataset_model.py new file mode 100644 index 0000000000000..ad5a712fcace0 --- /dev/null +++ b/airflow/migrations/versions/0005_3_0_0_add_name_field_to_dataset_model.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add name field to DatasetModel. + +Revision ID: 0d9e73a75ee4 +Revises: 044f740568ec +Create Date: 2024-08-13 09:45:32.213222 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "0d9e73a75ee4" +down_revision = "0bfc26bc256e" +branch_labels = None +depends_on = None +airflow_version = "3.0.0" + + +def upgrade(): + """Add name field to dataset model.""" + with op.batch_alter_table("dataset", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "name", + sa.String(length=3000).with_variant( + sa.String(length=3000, collation="latin1_general_cs"), "mysql" + ), + nullable=False, + ) + ) + + +def downgrade(): + """Remove name field to dataset model.""" + with op.batch_alter_table("dataset", schema=None) as batch_op: + batch_op.drop_column("name") diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index 5033da48a3059..b22da7038d5bc 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -148,6 +148,18 @@ class DatasetModel(Base): """ id = Column(Integer, primary_key=True, autoincrement=True) + name = Column( + String(length=3000).with_variant( + String( + length=3000, + # latin1 allows for more indexed length in mysql + # and this field should only be ascii chars + collation="latin1_general_cs", + ), + "mysql", + ), + nullable=False, + ) uri = Column( String(length=3000).with_variant( String( diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 2bbeb8aae917c..c8b690132de76 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -253,7 +253,7 @@ def encode_dataset_condition(var: BaseDataset) -> dict[str, Any]: :meta private: """ if isinstance(var, Dataset): - return {"__type": DAT.DATASET, "uri": var.uri, "extra": var.extra} + return {"__type": DAT.DATASET, "name": var.name, "uri": var.uri, "extra": var.extra} if isinstance(var, DatasetAlias): return {"__type": DAT.DATASET_ALIAS, "name": var.name} if isinstance(var, DatasetAll): @@ -271,7 +271,8 @@ def decode_dataset_condition(var: dict[str, Any]) -> BaseDataset: """ dat = var["__type"] if dat == DAT.DATASET: - return Dataset(var["uri"], extra=var["extra"]) + uri = var["uri"] + return Dataset(name=var.get("name", uri), uri=uri, extra=var["extra"]) if dat == DAT.DATASET_ALL: return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"])) if dat == DAT.DATASET_ANY: diff --git a/airflow/utils/db.py b/airflow/utils/db.py index ec1805608c636..091fcdfb6570b 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -91,7 +91,7 @@ class MappedClassProtocol(Protocol): _REVISION_HEADS_MAP = { "2.10.0": "22ed7efa9da2", - "3.0.0": "0bfc26bc256e", + "3.0.0": "0d9e73a75ee4", } diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index f91459ccf67e7..1921a05ad3126 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -60024966161ea6c01433f9a803aacd8b9d42cdab074a962e535ae3ad03da54bc \ No newline at end of file +63ad331c1b28dca55e57f53a7248ac36c017a917ae6831721c45a9d12c709ef6 \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index 0753fa2b873b0..1eec6f5bd12f7 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -4,11 +4,11 @@ - - + + %3 - + log @@ -449,244 +449,249 @@ dataset_alias_dataset - -dataset_alias_dataset - -alias_id - - [INTEGER] - NOT NULL - -dataset_id - - [INTEGER] - NOT NULL + +dataset_alias_dataset + +alias_id + + [INTEGER] + NOT NULL + +dataset_id + + [INTEGER] + NOT NULL dataset_alias--dataset_alias_dataset - -0..N 
-1 + +0..N +1 dataset_alias--dataset_alias_dataset - -0..N -1 + +0..N +1 dataset_alias_dataset_event - -dataset_alias_dataset_event - -alias_id - - [INTEGER] - NOT NULL - -event_id - - [INTEGER] - NOT NULL + +dataset_alias_dataset_event + +alias_id + + [INTEGER] + NOT NULL + +event_id + + [INTEGER] + NOT NULL dataset_alias--dataset_alias_dataset_event - -0..N -1 + +0..N +1 dataset_alias--dataset_alias_dataset_event - -0..N -1 + +0..N +1 dag_schedule_dataset_alias_reference - -dag_schedule_dataset_alias_reference - -alias_id - - [INTEGER] - NOT NULL - -dag_id - - [VARCHAR(250)] - NOT NULL - -created_at - - [TIMESTAMP] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL + +dag_schedule_dataset_alias_reference + +alias_id + + [INTEGER] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL + +created_at + + [TIMESTAMP] + NOT NULL + +updated_at + + [TIMESTAMP] + NOT NULL dataset_alias--dag_schedule_dataset_alias_reference - -0..N -1 + +0..N +1 dataset - -dataset - -id - - [INTEGER] - NOT NULL - -created_at - - [TIMESTAMP] - NOT NULL - -extra - - [JSON] - NOT NULL - -is_orphaned - - [BOOLEAN] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL - -uri - - [VARCHAR(3000)] - NOT NULL + +dataset + +id + + [INTEGER] + NOT NULL + +created_at + + [TIMESTAMP] + NOT NULL + +extra + + [JSON] + NOT NULL + +is_orphaned + + [BOOLEAN] + NOT NULL + +name + + [VARCHAR(3000)] + NOT NULL + +updated_at + + [TIMESTAMP] + NOT NULL + +uri + + [VARCHAR(3000)] + NOT NULL dataset--dataset_alias_dataset - -0..N -1 + +0..N +1 dataset--dataset_alias_dataset - -0..N -1 + +0..N +1 dag_schedule_dataset_reference - -dag_schedule_dataset_reference - -dag_id - - [VARCHAR(250)] - NOT NULL - -dataset_id - - [INTEGER] - NOT NULL - -created_at - - [TIMESTAMP] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL + +dag_schedule_dataset_reference + +dag_id + + [VARCHAR(250)] + NOT NULL + +dataset_id + + [INTEGER] + NOT NULL + +created_at + + [TIMESTAMP] + NOT NULL + +updated_at + + [TIMESTAMP] + NOT NULL dataset--dag_schedule_dataset_reference - -0..N -1 + +0..N +1 task_outlet_dataset_reference - -task_outlet_dataset_reference - -dag_id - - [VARCHAR(250)] - NOT NULL - -dataset_id - - [INTEGER] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -created_at - - [TIMESTAMP] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL + +task_outlet_dataset_reference + +dag_id + + [VARCHAR(250)] + NOT NULL + +dataset_id + + [INTEGER] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +created_at + + [TIMESTAMP] + NOT NULL + +updated_at + + [TIMESTAMP] + NOT NULL dataset--task_outlet_dataset_reference - -0..N -1 + +0..N +1 dataset_dag_run_queue - -dataset_dag_run_queue - -dataset_id - - [INTEGER] - NOT NULL - -target_dag_id - - [VARCHAR(250)] - NOT NULL - -created_at - - [TIMESTAMP] - NOT NULL + +dataset_dag_run_queue + +dataset_id + + [INTEGER] + NOT NULL + +target_dag_id + + [VARCHAR(250)] + NOT NULL + +created_at + + [TIMESTAMP] + NOT NULL dataset--dataset_dag_run_queue - -0..N -1 + +0..N +1 @@ -733,39 +738,39 @@ dataset_event--dataset_alias_dataset_event - -0..N -1 + +0..N +1 dataset_event--dataset_alias_dataset_event - -0..N -1 + +0..N +1 dagrun_dataset_event - -dagrun_dataset_event - -dag_run_id - - [INTEGER] - NOT NULL - -event_id - - [INTEGER] - NOT NULL + +dagrun_dataset_event + +dag_run_id + + [INTEGER] + NOT NULL + +event_id + + [INTEGER] + NOT NULL dataset_event--dagrun_dataset_event - -0..N -1 + +0..N +1 @@ -884,114 +889,114 @@ dag--dag_schedule_dataset_alias_reference - -0..N -1 + +0..N +1 dag--dag_schedule_dataset_reference - -0..N 
-1 + +0..N +1 dag--task_outlet_dataset_reference - -0..N -1 + +0..N +1 dag--dataset_dag_run_queue - -0..N -1 + +0..N +1 dag_tag - -dag_tag - -dag_id - - [VARCHAR(250)] - NOT NULL - -name - - [VARCHAR(100)] - NOT NULL + +dag_tag + +dag_id + + [VARCHAR(250)] + NOT NULL + +name + + [VARCHAR(100)] + NOT NULL dag--dag_tag - -0..N -1 + +0..N +1 dag_owner_attributes - -dag_owner_attributes - -dag_id - - [VARCHAR(250)] - NOT NULL - -owner - - [VARCHAR(500)] - NOT NULL - -link - - [VARCHAR(500)] - NOT NULL + +dag_owner_attributes + +dag_id + + [VARCHAR(250)] + NOT NULL + +owner + + [VARCHAR(500)] + NOT NULL + +link + + [VARCHAR(500)] + NOT NULL dag--dag_owner_attributes - -0..N -1 + +0..N +1 dag_warning - -dag_warning - -dag_id - - [VARCHAR(250)] - NOT NULL - -warning_type - - [VARCHAR(50)] - NOT NULL - -message - - [TEXT] - NOT NULL - -timestamp - - [TIMESTAMP] - NOT NULL + +dag_warning + +dag_id + + [VARCHAR(250)] + NOT NULL + +warning_type + + [VARCHAR(50)] + NOT NULL + +message + + [TEXT] + NOT NULL + +timestamp + + [TIMESTAMP] + NOT NULL dag--dag_warning - -0..N -1 + +0..N +1 @@ -1117,813 +1122,813 @@ dag_run--dagrun_dataset_event - -0..N -1 + +0..N +1 task_instance - -task_instance - -dag_id - - [VARCHAR(250)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -custom_operator_name - - [VARCHAR(1000)] - -duration - - [DOUBLE_PRECISION] - -end_date - - [TIMESTAMP] - -executor - - [VARCHAR(1000)] - -executor_config - - [BYTEA] - -external_executor_id - - [VARCHAR(250)] - -hostname - - [VARCHAR(1000)] - -job_id - - [INTEGER] - -max_tries - - [INTEGER] - -next_kwargs - - [JSON] - -next_method - - [VARCHAR(1000)] - -operator - - [VARCHAR(1000)] - -pid - - [INTEGER] - -pool - - [VARCHAR(256)] - NOT NULL - -pool_slots - - [INTEGER] - NOT NULL - -priority_weight - - [INTEGER] - -queue - - [VARCHAR(256)] - -queued_by_job_id - - [INTEGER] - -queued_dttm - - [TIMESTAMP] - -rendered_map_index - - [VARCHAR(250)] - -start_date - - [TIMESTAMP] - -state - - [VARCHAR(20)] - -task_display_name - - [VARCHAR(2000)] - -trigger_id - - [INTEGER] - -trigger_timeout - - [TIMESTAMP] - -try_number - - [INTEGER] - -unixname - - [VARCHAR(1000)] - -updated_at - - [TIMESTAMP] + +task_instance + +dag_id + + [VARCHAR(250)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +custom_operator_name + + [VARCHAR(1000)] + +duration + + [DOUBLE_PRECISION] + +end_date + + [TIMESTAMP] + +executor + + [VARCHAR(1000)] + +executor_config + + [BYTEA] + +external_executor_id + + [VARCHAR(250)] + +hostname + + [VARCHAR(1000)] + +job_id + + [INTEGER] + +max_tries + + [INTEGER] + +next_kwargs + + [JSON] + +next_method + + [VARCHAR(1000)] + +operator + + [VARCHAR(1000)] + +pid + + [INTEGER] + +pool + + [VARCHAR(256)] + NOT NULL + +pool_slots + + [INTEGER] + NOT NULL + +priority_weight + + [INTEGER] + +queue + + [VARCHAR(256)] + +queued_by_job_id + + [INTEGER] + +queued_dttm + + [TIMESTAMP] + +rendered_map_index + + [VARCHAR(250)] + +start_date + + [TIMESTAMP] + +state + + [VARCHAR(20)] + +task_display_name + + [VARCHAR(2000)] + +trigger_id + + [INTEGER] + +trigger_timeout + + [TIMESTAMP] + +try_number + + [INTEGER] + +unixname + + [VARCHAR(1000)] + +updated_at + + [TIMESTAMP] dag_run--task_instance - -0..N -1 + +0..N +1 dag_run--task_instance - -0..N -1 + +0..N +1 dag_run_note - -dag_run_note - -dag_run_id - - [INTEGER] - NOT NULL - -content - - [VARCHAR(1000)] - -created_at - - 
[TIMESTAMP] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL - -user_id - - [INTEGER] + +dag_run_note + +dag_run_id + + [INTEGER] + NOT NULL + +content + + [VARCHAR(1000)] + +created_at + + [TIMESTAMP] + NOT NULL + +updated_at + + [TIMESTAMP] + NOT NULL + +user_id + + [INTEGER] dag_run--dag_run_note - -1 -1 + +1 +1 task_reschedule - -task_reschedule - -id - - [INTEGER] - NOT NULL - -dag_id - - [VARCHAR(250)] - NOT NULL - -duration - - [INTEGER] - NOT NULL - -end_date - - [TIMESTAMP] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -reschedule_date - - [TIMESTAMP] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -start_date - - [TIMESTAMP] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -try_number - - [INTEGER] - NOT NULL + +task_reschedule + +id + + [INTEGER] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL + +duration + + [INTEGER] + NOT NULL + +end_date + + [TIMESTAMP] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +reschedule_date + + [TIMESTAMP] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +start_date + + [TIMESTAMP] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +try_number + + [INTEGER] + NOT NULL dag_run--task_reschedule - -0..N -1 + +0..N +1 dag_run--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 task_instance--task_reschedule - -0..N -1 + +0..N +1 rendered_task_instance_fields - -rendered_task_instance_fields - -dag_id - - [VARCHAR(250)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -k8s_pod_yaml - - [JSON] - -rendered_fields - - [JSON] - NOT NULL + +rendered_task_instance_fields + +dag_id + + [VARCHAR(250)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +k8s_pod_yaml + + [JSON] + +rendered_fields + + [JSON] + NOT NULL task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_instance--rendered_task_instance_fields - -0..N -1 + +0..N +1 task_fail - -task_fail - -id - - [INTEGER] - NOT NULL - -dag_id - - [VARCHAR(250)] - NOT NULL - -duration - - [INTEGER] - -end_date - - [TIMESTAMP] - -map_index - - [INTEGER] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -start_date - - [TIMESTAMP] - -task_id - - [VARCHAR(250)] - NOT NULL + +task_fail + +id + + [INTEGER] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL + +duration + + [INTEGER] + +end_date + + [TIMESTAMP] + +map_index + + [INTEGER] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +start_date + + [TIMESTAMP] + +task_id + + [VARCHAR(250)] + NOT NULL task_instance--task_fail - -0..N -1 + +0..N +1 task_instance--task_fail - -0..N -1 + +0..N +1 task_instance--task_fail - -0..N -1 + +0..N +1 task_instance--task_fail - -0..N -1 + +0..N +1 task_map - -task_map - -dag_id - - [VARCHAR(250)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -keys - - [JSON] - -length - - [INTEGER] - NOT NULL + +task_map + +dag_id + + [VARCHAR(250)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +keys + + [JSON] + +length + + [INTEGER] + NOT NULL task_instance--task_map - -0..N -1 + 
+0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 task_instance--task_map - -0..N -1 + +0..N +1 xcom - -xcom - -dag_run_id - - [INTEGER] - NOT NULL - -key - - [VARCHAR(512)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -dag_id - - [VARCHAR(250)] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -timestamp - - [TIMESTAMP] - NOT NULL - -value - - [BYTEA] + +xcom + +dag_run_id + + [INTEGER] + NOT NULL + +key + + [VARCHAR(512)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +dag_id + + [VARCHAR(250)] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +timestamp + + [TIMESTAMP] + NOT NULL + +value + + [BYTEA] task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance--xcom - -0..N -1 + +0..N +1 task_instance_note - -task_instance_note - -dag_id - - [VARCHAR(250)] - NOT NULL - -map_index - - [INTEGER] - NOT NULL - -run_id - - [VARCHAR(250)] - NOT NULL - -task_id - - [VARCHAR(250)] - NOT NULL - -content - - [VARCHAR(1000)] - -created_at - - [TIMESTAMP] - NOT NULL - -updated_at - - [TIMESTAMP] - NOT NULL - -user_id - - [INTEGER] + +task_instance_note + +dag_id + + [VARCHAR(250)] + NOT NULL + +map_index + + [INTEGER] + NOT NULL + +run_id + + [VARCHAR(250)] + NOT NULL + +task_id + + [VARCHAR(250)] + NOT NULL + +content + + [VARCHAR(1000)] + +created_at + + [TIMESTAMP] + NOT NULL + +updated_at + + [TIMESTAMP] + NOT NULL + +user_id + + [INTEGER] task_instance--task_instance_note - -0..N -1 + +0..N +1 task_instance--task_instance_note - -0..N -1 + +0..N +1 task_instance--task_instance_note - -0..N -1 + +0..N +1 task_instance--task_instance_note - -0..N -1 + +0..N +1 task_instance_history - -task_instance_history - -id - - [INTEGER] - NOT NULL - -custom_operator_name - - [VARCHAR(1000)] - -dag_id - - [VARCHAR(250)] - NOT NULL - -duration - - [DOUBLE_PRECISION] - -end_date - - [TIMESTAMP] - -executor - - [VARCHAR(1000)] - -executor_config - - [BYTEA] - -external_executor_id - - [VARCHAR(250)] - -hostname - - [VARCHAR(1000)] - -job_id - - [INTEGER] - -map_index - - [INTEGER] - NOT NULL - -max_tries - - [INTEGER] - -next_kwargs - - [JSON] - -next_method - - [VARCHAR(1000)] - -operator - - [VARCHAR(1000)] - -pid - - [INTEGER] - -pool - - [VARCHAR(256)] - NOT NULL - -pool_slots - - [INTEGER] - NOT NULL - -priority_weight - - [INTEGER] - -queue - - [VARCHAR(256)] - -queued_by_job_id - - [INTEGER] - -queued_dttm - - [TIMESTAMP] - -rendered_map_index - - [VARCHAR(250)] - -run_id - - [VARCHAR(250)] - NOT NULL - -start_date - - [TIMESTAMP] - -state - - [VARCHAR(20)] - -task_display_name - - [VARCHAR(2000)] - -task_id - - [VARCHAR(250)] - NOT NULL - -trigger_id - - [INTEGER] - -trigger_timeout - - [TIMESTAMP] - -try_number - - [INTEGER] - NOT NULL - -unixname - - [VARCHAR(1000)] - -updated_at - - [TIMESTAMP] + +task_instance_history + +id + + [INTEGER] + NOT NULL + +custom_operator_name + + [VARCHAR(1000)] + +dag_id + + [VARCHAR(250)] + NOT NULL + +duration + + [DOUBLE_PRECISION] + +end_date + + [TIMESTAMP] + +executor + + [VARCHAR(1000)] + +executor_config + + [BYTEA] + +external_executor_id + + [VARCHAR(250)] + +hostname + + [VARCHAR(1000)] + +job_id + + [INTEGER] + +map_index + + [INTEGER] + NOT NULL + +max_tries + + [INTEGER] + +next_kwargs + + [JSON] + +next_method + + [VARCHAR(1000)] + +operator + + [VARCHAR(1000)] + +pid + + [INTEGER] + +pool + + [VARCHAR(256)] + NOT NULL + 
+pool_slots + + [INTEGER] + NOT NULL + +priority_weight + + [INTEGER] + +queue + + [VARCHAR(256)] + +queued_by_job_id + + [INTEGER] + +queued_dttm + + [TIMESTAMP] + +rendered_map_index + + [VARCHAR(250)] + +run_id + + [VARCHAR(250)] + NOT NULL + +start_date + + [TIMESTAMP] + +state + + [VARCHAR(20)] + +task_display_name + + [VARCHAR(2000)] + +task_id + + [VARCHAR(250)] + NOT NULL + +trigger_id + + [INTEGER] + +trigger_timeout + + [TIMESTAMP] + +try_number + + [INTEGER] + NOT NULL + +unixname + + [VARCHAR(1000)] + +updated_at + + [TIMESTAMP] task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 task_instance--task_instance_history - -0..N -1 + +0..N +1 @@ -1958,314 +1963,314 @@ trigger--task_instance - -0..N -{0,1} + +0..N +{0,1} alembic_version - -alembic_version - -version_num - - [VARCHAR(32)] - NOT NULL + +alembic_version + +version_num + + [VARCHAR(32)] + NOT NULL session - -session - -id - - [INTEGER] - NOT NULL - -data - - [BYTEA] - -expiry - - [TIMESTAMP] - -session_id - - [VARCHAR(255)] + +session + +id + + [INTEGER] + NOT NULL + +data + + [BYTEA] + +expiry + + [TIMESTAMP] + +session_id + + [VARCHAR(255)] ab_user - -ab_user - -id - - [INTEGER] - NOT NULL - -active - - [BOOLEAN] - -changed_by_fk - - [INTEGER] - -changed_on - - [TIMESTAMP] - -created_by_fk - - [INTEGER] - -created_on - - [TIMESTAMP] - -email - - [VARCHAR(512)] - NOT NULL - -fail_login_count - - [INTEGER] - -first_name - - [VARCHAR(256)] - NOT NULL - -last_login - - [TIMESTAMP] - -last_name - - [VARCHAR(256)] - NOT NULL - -login_count - - [INTEGER] - -password - - [VARCHAR(256)] - -username - - [VARCHAR(512)] - NOT NULL + +ab_user + +id + + [INTEGER] + NOT NULL + +active + + [BOOLEAN] + +changed_by_fk + + [INTEGER] + +changed_on + + [TIMESTAMP] + +created_by_fk + + [INTEGER] + +created_on + + [TIMESTAMP] + +email + + [VARCHAR(512)] + NOT NULL + +fail_login_count + + [INTEGER] + +first_name + + [VARCHAR(256)] + NOT NULL + +last_login + + [TIMESTAMP] + +last_name + + [VARCHAR(256)] + NOT NULL + +login_count + + [INTEGER] + +password + + [VARCHAR(256)] + +username + + [VARCHAR(512)] + NOT NULL ab_user--ab_user - -0..N -{0,1} + +0..N +{0,1} ab_user--ab_user - -0..N -{0,1} + +0..N +{0,1} ab_user_role - -ab_user_role - -id - - [INTEGER] - NOT NULL - -role_id - - [INTEGER] - -user_id - - [INTEGER] + +ab_user_role + +id + + [INTEGER] + NOT NULL + +role_id + + [INTEGER] + +user_id + + [INTEGER] ab_user--ab_user_role - -0..N -{0,1} + +0..N +{0,1} ab_register_user - -ab_register_user - -id - - [INTEGER] - NOT NULL - -email - - [VARCHAR(512)] - NOT NULL - -first_name - - [VARCHAR(256)] - NOT NULL - -last_name - - [VARCHAR(256)] - NOT NULL - -password - - [VARCHAR(256)] - -registration_date - - [TIMESTAMP] - -registration_hash - - [VARCHAR(256)] - -username - - [VARCHAR(512)] - NOT NULL + +ab_register_user + +id + + [INTEGER] + NOT NULL + +email + + [VARCHAR(512)] + NOT NULL + +first_name + + [VARCHAR(256)] + NOT NULL + +last_name + + [VARCHAR(256)] + NOT NULL + +password + + [VARCHAR(256)] + +registration_date + + [TIMESTAMP] + +registration_hash + + [VARCHAR(256)] + +username + + [VARCHAR(512)] + NOT NULL ab_permission - -ab_permission - -id - - [INTEGER] - NOT NULL - -name - - [VARCHAR(100)] - NOT NULL + +ab_permission + +id + + [INTEGER] + NOT NULL + +name + + [VARCHAR(100)] + NOT NULL ab_permission_view - -ab_permission_view - -id - - [INTEGER] - NOT NULL - -permission_id - - [INTEGER] - -view_menu_id - - 
[INTEGER] + +ab_permission_view + +id + + [INTEGER] + NOT NULL + +permission_id + + [INTEGER] + +view_menu_id + + [INTEGER] ab_permission--ab_permission_view - -0..N -{0,1} + +0..N +{0,1} ab_permission_view_role - -ab_permission_view_role - -id - - [INTEGER] - NOT NULL - -permission_view_id - - [INTEGER] - -role_id - - [INTEGER] + +ab_permission_view_role + +id + + [INTEGER] + NOT NULL + +permission_view_id + + [INTEGER] + +role_id + + [INTEGER] ab_permission_view--ab_permission_view_role - -0..N -{0,1} + +0..N +{0,1} ab_view_menu - -ab_view_menu - -id - - [INTEGER] - NOT NULL - -name - - [VARCHAR(250)] - NOT NULL + +ab_view_menu + +id + + [INTEGER] + NOT NULL + +name + + [VARCHAR(250)] + NOT NULL ab_view_menu--ab_permission_view - -0..N -{0,1} + +0..N +{0,1} ab_role - -ab_role - -id - - [INTEGER] - NOT NULL - -name - - [VARCHAR(64)] - NOT NULL + +ab_role + +id + + [INTEGER] + NOT NULL + +name + + [VARCHAR(64)] + NOT NULL ab_role--ab_user_role - -0..N -{0,1} + +0..N +{0,1} ab_role--ab_permission_view_role - -0..N -{0,1} + +0..N +{0,1} diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index d93e0d78d8f67..feacedce6e680 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``0bfc26bc256e`` (head) | ``d0f1c55954fa`` | ``3.0.0`` | Rename DagModel schedule_interval to timetable_summary. | +| ``0d9e73a75ee4`` (head) | ``0bfc26bc256e`` | ``3.0.0`` | Add name field to DatasetModel. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``0bfc26bc256e`` | ``d0f1c55954fa`` | ``3.0.0`` | Rename DagModel schedule_interval to timetable_summary. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``d0f1c55954fa`` | ``044f740568ec`` | ``3.0.0`` | Remove SubDAGs: ``is_subdag`` & ``root_dag_id`` columns from | | | | | DAG table. 
| From 94895e0f6746cdbcd544e32cd61344a9e3da9a66 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 21 Aug 2024 14:22:44 +0800 Subject: [PATCH 2/7] Allow cross attr reference --- airflow/datasets/__init__.py | 52 +++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 2fa1c972853b2..4dc385b6c5727 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -257,6 +257,17 @@ class DatasetAliasEvent(TypedDict): dest_dataset_uri: str +@attr.define() +class NoComparison(ArithmeticError): + """Exception for when two datasets cannot be compared directly.""" + + a: Dataset + b: Dataset + + def __str__(self) -> str: + return f"Can not compare {self.a} and {self.b}" + + @attr.define() class Dataset(os.PathLike, BaseDataset): """A representation of data dependencies between workflows.""" @@ -275,21 +286,44 @@ class Dataset(os.PathLike, BaseDataset): def __attrs_post_init__(self) -> None: if self.name is None and self.uri is None: raise TypeError("Dataset requires either name or URI") - if self.name is None: - self.name = self.uri - elif self.uri is None: - self.uri = self.name def __fspath__(self) -> str: return self.uri def __eq__(self, other: Any) -> bool: - if isinstance(other, self.__class__): + """ + Check equality of two datasets. + + Since either *name* or *uri* is required, and we ensure integrity when + DAG files are parsed, we only need to consider the following combos: + + * Both datasets have name and uri defined: Both fields must match. + * One dataset have only one field (name or uri) defined: The field + defined by both must match. + * Both datasets have the same one field defined: The field must match. + * Either dataset has the other field defined (e.g. *self* defines only + *name*, but *other* only *uri*): The two cannot be reliably compared, + and (a subclass of) *ArithmeticError* is raised. + + In the last case, we can still check dataset equality by querying the + database. We do not do here though since that has too much performance + implication. The call site should consider the possibility instead. + + However, since *Dataset* objects created from the meta-database (e.g. + those in the task execution context) would have both concrete name and + URI values filled by the DAG parser. Non-comparability only happens if + the user accesses the dataset objects that aren't created from the + database, say globally in a DAG file. This is discouraged anyway. + """ + if not isinstance(other, self.__class__): + return NotImplemented + if self.name is not None and other.name is not None: + if self.uri is None or other.uri is None: + return self.name == other.name + return self.name == other.name and self.uri == other.uri + if self.uri is not None and other.uri is not None: return self.uri == other.uri - return NotImplemented - - def __hash__(self) -> int: - return hash(self.uri) + raise NoComparison(self, other) @property def normalized_uri(self) -> str | None: From c9b310f4e250f5d70825f86960802786c99b135b Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 27 Aug 2024 19:47:03 +0800 Subject: [PATCH 3/7] Try to auto-fill name from other datasets Instead of automatically fill from the other field, try to find in the database if there's a best matching choice, and use that instead. Various issues. Going to write them up later... 
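Roughly, the intended lookup during DAG sync is sketched below. This is a
simplified illustration only (the helper name is made up; the real logic is
inlined in DAG.bulk_write_to_db in this patch): a named dataset is matched by
name, while an unnamed one, whose name defaults to its URI, is matched against
either column, preferring an exact URI match.

    # Illustrative sketch, not part of the patch itself.
    from sqlalchemy import case, or_, select

    from airflow.models.dataset import DatasetModel

    def find_matching_dataset(session, dataset):
        stmt = select(DatasetModel)
        if dataset.name:
            # A named dataset is matched by its name.
            stmt = stmt.where(DatasetModel.name == dataset.name)
        else:
            # An unnamed dataset defaults its name to the URI, so match either
            # column, preferring a row whose URI matches exactly.
            stmt = stmt.where(
                or_(DatasetModel.uri == dataset.uri, DatasetModel.name == dataset.uri)
            ).order_by(case((DatasetModel.uri == dataset.uri, 0), else_=1))
        return session.scalar(stmt.limit(1))
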
--- airflow/datasets/__init__.py | 3 - airflow/datasets/manager.py | 10 +- airflow/datasets/references.py | 122 ++++++++++++++ ...5_3_0_0_add_name_field_to_dataset_model.py | 48 ++++-- airflow/models/dag.py | 150 +++++++++--------- airflow/models/dataset.py | 10 +- docs/apache-airflow/img/airflow_erd.sha256 | 2 +- tests/models/test_dag.py | 6 +- .../serialization/test_serialized_objects.py | 2 +- 9 files changed, 255 insertions(+), 98 deletions(-) create mode 100644 airflow/datasets/references.py diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 4dc385b6c5727..eea3184747f1c 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -233,9 +233,6 @@ def __eq__(self, other: Any) -> bool: return self.name == other.name return NotImplemented - def __hash__(self) -> int: - return hash(self.name) - def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ Iterate a dataset alias as dag dependency. diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 058eef6ab8922..850462fc557c7 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -25,7 +25,6 @@ from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf -from airflow.datasets import Dataset from airflow.listeners.listener import get_listener_manager from airflow.models.dagbag import DagPriorityParsingRequest from airflow.models.dataset import ( @@ -43,6 +42,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.session import Session + from airflow.datasets import Dataset from airflow.models.dag import DagModel from airflow.models.taskinstance import TaskInstance @@ -58,14 +58,18 @@ class DatasetManager(LoggingMixin): def __init__(self, **kwargs): super().__init__(**kwargs) - def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None: + def create_datasets(self, dataset_models: Iterable[DatasetModel], session: Session) -> None: """Create new datasets.""" for dataset_model in dataset_models: + if not dataset_model.name: + dataset_model.name = dataset_model.uri + elif not dataset_model.uri: + dataset_model.uri = dataset_model.name session.add(dataset_model) session.flush() for dataset_model in dataset_models: - self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra)) + self.notify_dataset_created(dataset=dataset_model.as_public()) @classmethod @internal_api_call diff --git a/airflow/datasets/references.py b/airflow/datasets/references.py new file mode 100644 index 0000000000000..612c5a4b4459a --- /dev/null +++ b/airflow/datasets/references.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Dataset reference objects. 
+ +These are intermediate representations of DAG- and task-level references to +Dataset and DatasetAlias. These are meant only for Airflow internals so the DAG +processor can collect information from DAGs without the "full picture", which is +only available when it updates the database. +""" + +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + from airflow.datasets import Dataset, DatasetAlias + from airflow.models.dataset import ( + DagScheduleDatasetAliasReference, + DagScheduleDatasetReference, + DatasetAliasModel, + DatasetModel, + TaskOutletDatasetReference, + ) + +DatasetReference = Union["DatasetNameReference", "DatasetURIReference"] + +DatasetOrAliasReference = Union[DatasetReference, "DatasetAliasReference"] + + +def create_dag_dataset_reference(source: Dataset) -> DatasetReference: + """Create reference to a dataset.""" + if source.name: + return DatasetNameReference(source.name) + return DatasetURIReference(source.uri) + + +def create_dag_dataset_alias_reference(source: DatasetAlias) -> DatasetAliasReference: + """Create reference to a dataset or dataset alias.""" + return DatasetAliasReference(source.name) + + +@dataclasses.dataclass +class DatasetNameReference: + """Reference to a dataset by name.""" + + name: str + + def __hash__(self) -> int: + return hash((self.__class__.__name__, self.name)) + + +@dataclasses.dataclass +class DatasetURIReference: + """Reference to a dataset by URI.""" + + uri: str + + def __hash__(self) -> int: + return hash((self.__class__.__name__, self.uri)) + + +@dataclasses.dataclass +class DatasetAliasReference: + """Reference to a dataset alias.""" + + name: str + + def __hash__(self) -> int: + return hash((self.__class__.__name__, self.name)) + + +def resolve_dag_schedule_reference( + ref: DatasetOrAliasReference, + *, + dag_id: str, + dataset_names: dict[str, DatasetModel], + dataset_uris: dict[str, DatasetModel], + alias_names: dict[str, DatasetAliasModel], +) -> DagScheduleDatasetReference | DagScheduleDatasetAliasReference: + """Create database representation from DAG-level references.""" + from airflow.models.dataset import DagScheduleDatasetAliasReference, DagScheduleDatasetReference + + if isinstance(ref, DatasetNameReference): + return DagScheduleDatasetReference(dataset_id=dataset_names[ref.name].id, dag_id=dag_id) + elif isinstance(ref, DatasetURIReference): + return DagScheduleDatasetReference(dataset_id=dataset_uris[ref.uri].id, dag_id=dag_id) + return DagScheduleDatasetAliasReference(alias_id=alias_names[ref.name], dag_id=dag_id) + + +def resolve_task_outlet_reference( + ref: DatasetReference, + *, + dag_id: str, + task_id: str, + dataset_names: dict[str, DatasetModel], + dataset_uris: dict[str, DatasetModel], +) -> TaskOutletDatasetReference: + """Create database representation from task-level references.""" + from airflow.models.dataset import TaskOutletDatasetReference + + if isinstance(ref, DatasetURIReference): + dataset = dataset_uris[ref.uri] + else: + dataset = dataset_names[ref.name] + return TaskOutletDatasetReference(dataset_id=dataset.id, dag_id=dag_id, task_id=task_id) diff --git a/airflow/migrations/versions/0005_3_0_0_add_name_field_to_dataset_model.py b/airflow/migrations/versions/0005_3_0_0_add_name_field_to_dataset_model.py index ad5a712fcace0..a50087a79e06e 100644 --- a/airflow/migrations/versions/0005_3_0_0_add_name_field_to_dataset_model.py +++ b/airflow/migrations/versions/0005_3_0_0_add_name_field_to_dataset_model.py @@ -19,16 +19,22 @@ """ 
Add name field to DatasetModel. +This also renames two indexes. Index names are scoped to the entire database. +Airflow generally includes the table's name to manually scope the index, but +``idx_uri_unique`` (on DatasetModel) and ``idx_name_unique`` (on +DatasetAliasModel) do not do this. They are renamed here so we can create a +unique index on DatasetModel as well. + Revision ID: 0d9e73a75ee4 -Revises: 044f740568ec +Revises: 0bfc26bc256e Create Date: 2024-08-13 09:45:32.213222 - """ from __future__ import annotations import sqlalchemy as sa from alembic import op +from sqlalchemy.orm import Session # revision identifiers, used by Alembic. revision = "0d9e73a75ee4" @@ -37,22 +43,38 @@ depends_on = None airflow_version = "3.0.0" +_NAME_COLUMN_TYPE = sa.String(length=3000).with_variant( + sa.String(length=3000, collation="latin1_general_cs"), + dialect_name="mysql", +) + def upgrade(): - """Add name field to dataset model.""" + # Fix index name on DatasetAlias. + with op.batch_alter_table("dataset_alias", schema=None) as batch_op: + batch_op.drop_index("idx_name_unique") + batch_op.create_index("idx_dataset_alias_name_unique", ["name"], unique=True) + # Fix index name (of 'uri') on Dataset. + # Add 'name' column. Set it to nullable for now. + with op.batch_alter_table("dataset", schema=None) as batch_op: + batch_op.drop_index("idx_uri_unique") + batch_op.create_index("idx_dataset_uri_unique", ["uri"], unique=True) + batch_op.add_column(sa.Column("name", _NAME_COLUMN_TYPE)) + # Fill name from uri column. + Session(bind=op.get_bind()).execute(sa.text("update dataset set name=uri")) + # Set the name column non-nullable. + # Now with values in there, we can create the unique constraint and index. with op.batch_alter_table("dataset", schema=None) as batch_op: - batch_op.add_column( - sa.Column( - "name", - sa.String(length=3000).with_variant( - sa.String(length=3000, collation="latin1_general_cs"), "mysql" - ), - nullable=False, - ) - ) + batch_op.alter_column("name", existing_type=_NAME_COLUMN_TYPE, nullable=False) + batch_op.create_index("idx_dataset_name_unique", ["name"], unique=True) def downgrade(): - """Remove name field to dataset model.""" with op.batch_alter_table("dataset", schema=None) as batch_op: + batch_op.drop_index("idx_dataset_name_unique") batch_op.drop_column("name") + batch_op.drop_index("idx_dataset_uri_unique") + batch_op.create_index("idx_uri_unique", ["uri"], unique=True) + with op.batch_alter_table("dataset_alias", schema=None) as batch_op: + batch_op.drop_index("idx_dataset_alias_name_unique") + batch_op.create_index("idx_name_unique", ["name"], unique=True) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index bd1679e177b42..79450fdb18551 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -83,6 +83,7 @@ from airflow.configuration import conf as airflow_conf, secrets_backend_list from airflow.datasets import BaseDataset, Dataset, DatasetAlias, DatasetAll from airflow.datasets.manager import dataset_manager +from airflow.datasets.references import resolve_dag_schedule_reference, resolve_task_outlet_reference from airflow.exceptions import ( AirflowException, DuplicateTaskIdFound, @@ -100,11 +101,7 @@ from airflow.models.dagcode import DagCode from airflow.models.dagpickle import DagPickle from airflow.models.dagrun import RUN_ID_REGEX, DagRun -from airflow.models.dataset import ( - DatasetAliasModel, - DatasetDagRunQueue, - DatasetModel, -) +from airflow.models.dataset import DatasetAliasModel, DatasetDagRunQueue, DatasetModel from 
airflow.models.param import DagParam, ParamsDict from airflow.models.taskinstance import ( Context, @@ -2642,8 +2639,6 @@ def bulk_write_to_db( if not dags: return - from airflow.models.dataset import DagScheduleDatasetAliasReference - log.info("Sync %s DAGs", len(dags)) dag_by_ids = {dag.dag_id: dag for dag in dags} @@ -2742,19 +2737,20 @@ def bulk_write_to_db( DagCode.bulk_sync_to_db(filelocs, session=session) from airflow.datasets import Dataset - from airflow.models.dataset import ( - DagScheduleDatasetReference, - DatasetModel, - TaskOutletDatasetReference, + from airflow.datasets.references import ( + DatasetOrAliasReference, + DatasetReference, + create_dag_dataset_alias_reference, + create_dag_dataset_reference, ) + from airflow.models.dataset import DatasetModel - dag_references: dict[str, set[Dataset | DatasetAlias]] = defaultdict(set) - outlet_references = defaultdict(set) + dag_references: dict[str, set[DatasetOrAliasReference]] = defaultdict(set) + outlet_references: dict[tuple[str, str], set[DatasetReference]] = defaultdict(set) # We can't use a set here as we want to preserve order outlet_datasets: dict[DatasetModel, None] = {} input_datasets: dict[DatasetModel, None] = {} - outlet_dataset_alias_models: set[DatasetAliasModel] = set() - input_dataset_aliases: set[DatasetAliasModel] = set() + lineage_dataset_alias_models: list[DatasetAliasModel] = [] # here we go through dags and tasks to check for dataset references # if there are now None and previously there were some, we delete them @@ -2770,105 +2766,111 @@ def bulk_write_to_db( curr_orm_dag.schedule_dataset_alias_references = [] else: for _, dataset in dataset_condition.iter_datasets(): - dag_references[dag.dag_id].add(Dataset(uri=dataset.uri)) + dag_references[dag.dag_id].add(create_dag_dataset_reference(dataset)) input_datasets[DatasetModel.from_public(dataset)] = None for dataset_alias in dataset_condition.iter_dataset_aliases(): - dag_references[dag.dag_id].add(dataset_alias) - input_dataset_aliases.add(DatasetAliasModel.from_public(dataset_alias)) + dag_references[dag.dag_id].add(create_dag_dataset_alias_reference(dataset_alias)) + lineage_dataset_alias_models.append(DatasetAliasModel.from_public(dataset_alias)) curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references for task in dag.tasks: dataset_outlets: list[Dataset] = [] - dataset_alias_outlets: set[DatasetAlias] = set() for outlet in task.outlets: if isinstance(outlet, Dataset): dataset_outlets.append(outlet) elif isinstance(outlet, DatasetAlias): - dataset_alias_outlets.add(outlet) - - if not dataset_outlets: - if curr_outlet_references: - this_task_outlet_refs = [ - x - for x in curr_outlet_references - if x.dag_id == dag.dag_id and x.task_id == task.task_id - ] - for ref in this_task_outlet_refs: - curr_outlet_references.remove(ref) + lineage_dataset_alias_models.append(DatasetAliasModel.from_public(outlet)) + + if not dataset_outlets and curr_outlet_references: + this_task_outlet_refs = [ + x + for x in curr_outlet_references + if x.dag_id == dag.dag_id and x.task_id == task.task_id + ] + for ref in this_task_outlet_refs: + curr_outlet_references.remove(ref) for d in dataset_outlets: - outlet_references[(task.dag_id, task.task_id)].add(d.uri) + outlet_references[(task.dag_id, task.task_id)].add(create_dag_dataset_reference(d)) outlet_datasets[DatasetModel.from_public(d)] = None - for d_a in dataset_alias_outlets: - outlet_dataset_alias_models.add(DatasetAliasModel.from_public(d_a)) - all_datasets = outlet_datasets 
all_datasets.update(input_datasets) # store datasets - stored_datasets: dict[str, DatasetModel] = {} + stored_datasets_by_name: dict[str, DatasetModel] = {} + stored_datasets_by_uri: dict[str, DatasetModel] = {} new_datasets: list[DatasetModel] = [] for dataset in all_datasets: - stored_dataset = session.scalar( - select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1) - ) - if stored_dataset: + stmt = select(DatasetModel) + if dataset.name: + stmt = stmt.where(DatasetModel.name == dataset.name) + else: + # We match both name and URI here because a new unnamed dataset + # will use the URI as the default name, and that would cause a + # conflict if another dataset (with a different URI) uses that + # same value as the name. + stmt = ( + stmt.where(or_(DatasetModel.uri == dataset.uri, DatasetModel.name == dataset.uri)) + # If both cases are found, prefer the dataset with matching URI. + .order_by(case((DatasetModel.uri == dataset.uri, 0), else_=1)) + ) + if stored_dataset := session.scalar(stmt.where(DatasetModel.uri == dataset.uri).limit(1)): # Some datasets may have been previously unreferenced, and therefore orphaned by the # scheduler. But if we're here, then we have found that dataset again in our DAGs, which # means that it is no longer an orphan, so set is_orphaned to False. stored_dataset.is_orphaned = expression.false() - stored_datasets[stored_dataset.uri] = stored_dataset + if dataset.name: + stored_dataset.name = dataset.name + if dataset.uri: + stored_dataset.uri = dataset.uri + stored_dataset.extra = dataset.extra + stored_datasets_by_name[stored_dataset.name] = stored_dataset + stored_datasets_by_uri[stored_dataset.uri] = stored_dataset else: new_datasets.append(dataset) dataset_manager.create_datasets(dataset_models=new_datasets, session=session) - stored_datasets.update({dataset.uri: dataset for dataset in new_datasets}) + stored_datasets_by_name.update((ds.name, ds) for ds in new_datasets) + stored_datasets_by_uri.update((ds.uri, ds) for ds in new_datasets) del new_datasets del all_datasets # store dataset aliases - all_datasets_alias_models = input_dataset_aliases | outlet_dataset_alias_models stored_dataset_aliases: dict[str, DatasetAliasModel] = {} - new_dataset_alias_models: set[DatasetAliasModel] = set() - if all_datasets_alias_models: - all_dataset_alias_names = {dataset_alias.name for dataset_alias in all_datasets_alias_models} - + new_dataset_alias_models: dict[str, DatasetAliasModel] = {} + if lineage_dataset_alias_models: + all_dataset_alias_names = {dataset_alias.name for dataset_alias in lineage_dataset_alias_models} stored_dataset_aliases = { dsa_m.name: dsa_m for dsa_m in session.scalars( select(DatasetAliasModel).where(DatasetAliasModel.name.in_(all_dataset_alias_names)) - ).fetchall() + ) } - - if stored_dataset_aliases: - new_dataset_alias_models = { - dataset_alias_model - for dataset_alias_model in all_datasets_alias_models - if dataset_alias_model.name not in stored_dataset_aliases.keys() - } - else: - new_dataset_alias_models = all_datasets_alias_models - - session.add_all(new_dataset_alias_models) + new_dataset_alias_models = { + dataset_alias_model.name: dataset_alias_model + for dataset_alias_model in lineage_dataset_alias_models + if dataset_alias_model.name not in stored_dataset_aliases + } + session.add_all(new_dataset_alias_models.values()) session.flush() - stored_dataset_aliases.update( - {dataset_alias.name: dataset_alias for dataset_alias in new_dataset_alias_models} - ) + stored_dataset_aliases.update(new_dataset_alias_models) 
del new_dataset_alias_models - del all_datasets_alias_models + del lineage_dataset_alias_models # reconcile dag-schedule-on-dataset and dag-schedule-on-dataset-alias references - for dag_id, base_dataset_list in dag_references.items(): + for dag_id, dag_refs in dag_references.items(): dag_refs_needed = { - DagScheduleDatasetReference(dataset_id=stored_datasets[base_dataset.uri].id, dag_id=dag_id) - if isinstance(base_dataset, Dataset) - else DagScheduleDatasetAliasReference( - alias_id=stored_dataset_aliases[base_dataset.name].id, dag_id=dag_id + resolve_dag_schedule_reference( + ref, + dag_id=dag_id, + dataset_names=stored_datasets_by_name, + dataset_uris=stored_datasets_by_uri, + alias_names=stored_dataset_aliases, ) - for base_dataset in base_dataset_list + for ref in dag_refs } dag_refs_stored = ( @@ -2888,10 +2890,16 @@ def bulk_write_to_db( existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr) # reconcile task-outlet-dataset references - for (dag_id, task_id), uri_list in outlet_references.items(): + for (dag_id, task_id), outlet_refs in outlet_references.items(): task_refs_needed = { - TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id) - for uri in uri_list + resolve_task_outlet_reference( + ref, + dag_id=dag_id, + task_id=task_id, + dataset_names=stored_datasets_by_name, + dataset_uris=stored_datasets_by_uri, + ) + for ref in outlet_refs } task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)] task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored} diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index b22da7038d5bc..f28573df8617c 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -106,7 +106,7 @@ class DatasetAliasModel(Base): __tablename__ = "dataset_alias" __table_args__ = ( - Index("idx_name_unique", name, unique=True), + Index("idx_dataset_alias_name_unique", name, unique=True), {"sqlite_autoincrement": True}, # ensures PK values not reused ) @@ -182,13 +182,17 @@ class DatasetModel(Base): __tablename__ = "dataset" __table_args__ = ( - Index("idx_uri_unique", uri, unique=True), + Index("idx_dataset_name_unique", name, unique=True), + Index("idx_dataset_uri_unique", uri, unique=True), {"sqlite_autoincrement": True}, # ensures PK values not reused ) @classmethod def from_public(cls, obj: Dataset) -> DatasetModel: - return cls(uri=obj.uri, extra=obj.extra) + return cls(name=obj.name, uri=obj.uri, extra=obj.extra) + + def as_public(self) -> Dataset: + return Dataset(name=self.name, uri=self.uri, extra=self.extra) def __init__(self, uri: str, **kwargs): try: diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index 1921a05ad3126..bfa06102aada3 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -63ad331c1b28dca55e57f53a7248ac36c017a917ae6831721c45a9d12c709ef6 \ No newline at end of file +1e6c3a9cc81a67cc76ef515779b5989d6a94e45a039c6a2b563b5b5714efcaea \ No newline at end of file diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 17e9c1b165a9f..a0343fe85c4d9 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -2651,10 +2651,10 @@ def test_dataset_expression(self, session: Session) -> None: dag = DAG( dag_id="test_dag_dataset_expression", schedule=DatasetAny( - Dataset("s3://dag1/output_1.txt", {"hi": "bye"}), + Dataset("s3://dag1/output_1.txt", extra={"hi": "bye"}), DatasetAll( - 
Dataset("s3://dag2/output_1.txt", {"hi": "bye"}), - Dataset("s3://dag3/output_3.txt", {"hi": "bye"}), + Dataset("s3://dag2/output_1.txt", extra={"hi": "bye"}), + Dataset("s3://dag3/output_3.txt", extra={"hi": "bye"}), ), DatasetAlias(name="test_name"), ), diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 20ed3954e12aa..ff89a7bfeeb47 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -328,7 +328,7 @@ def test_backcompat_deserialize_connection(conn_uri): id=1, filename="test_file", elasticsearch_id="test_id", created_at=datetime.now() ), DagTagPydantic: DagTag(), - DatasetPydantic: Dataset("uri", {}), + DatasetPydantic: Dataset("uri"), DatasetEventPydantic: DatasetEvent(), } From b310108ce820eb486602f4b13388df5b8b973af5 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 28 Aug 2024 13:23:13 +0800 Subject: [PATCH 4/7] Also do not use __hash__ for dedup --- airflow/models/dag.py | 39 +++++++++++++++++---------------------- airflow/models/dataset.py | 28 +++++++++++----------------- 2 files changed, 28 insertions(+), 39 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 79450fdb18551..2a31e8b20d804 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2747,10 +2747,8 @@ def bulk_write_to_db( dag_references: dict[str, set[DatasetOrAliasReference]] = defaultdict(set) outlet_references: dict[tuple[str, str], set[DatasetReference]] = defaultdict(set) - # We can't use a set here as we want to preserve order - outlet_datasets: dict[DatasetModel, None] = {} - input_datasets: dict[DatasetModel, None] = {} - lineage_dataset_alias_models: list[DatasetAliasModel] = [] + all_datasets: list[DatasetModel] = [] + all_dataset_aliases: list[DatasetAliasModel] = [] # here we go through dags and tasks to check for dataset references # if there are now None and previously there were some, we delete them @@ -2767,11 +2765,11 @@ def bulk_write_to_db( else: for _, dataset in dataset_condition.iter_datasets(): dag_references[dag.dag_id].add(create_dag_dataset_reference(dataset)) - input_datasets[DatasetModel.from_public(dataset)] = None + all_datasets.append(DatasetModel.from_public(dataset)) for dataset_alias in dataset_condition.iter_dataset_aliases(): dag_references[dag.dag_id].add(create_dag_dataset_alias_reference(dataset_alias)) - lineage_dataset_alias_models.append(DatasetAliasModel.from_public(dataset_alias)) + all_dataset_aliases.append(DatasetAliasModel.from_public(dataset_alias)) curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references for task in dag.tasks: @@ -2780,7 +2778,7 @@ def bulk_write_to_db( if isinstance(outlet, Dataset): dataset_outlets.append(outlet) elif isinstance(outlet, DatasetAlias): - lineage_dataset_alias_models.append(DatasetAliasModel.from_public(outlet)) + all_dataset_aliases.append(DatasetAliasModel.from_public(outlet)) if not dataset_outlets and curr_outlet_references: this_task_outlet_refs = [ @@ -2793,10 +2791,7 @@ def bulk_write_to_db( for d in dataset_outlets: outlet_references[(task.dag_id, task.task_id)].add(create_dag_dataset_reference(d)) - outlet_datasets[DatasetModel.from_public(d)] = None - - all_datasets = outlet_datasets - all_datasets.update(input_datasets) + all_datasets.append(DatasetModel.from_public(d)) # store datasets stored_datasets_by_name: dict[str, DatasetModel] = {} @@ -2839,26 +2834,26 @@ def bulk_write_to_db( # store dataset aliases 
stored_dataset_aliases: dict[str, DatasetAliasModel] = {} - new_dataset_alias_models: dict[str, DatasetAliasModel] = {} - if lineage_dataset_alias_models: - all_dataset_alias_names = {dataset_alias.name for dataset_alias in lineage_dataset_alias_models} + new_dataset_aliases: dict[str, DatasetAliasModel] = {} + if all_dataset_aliases: + all_dataset_alias_names = {dataset_alias.name for dataset_alias in all_dataset_aliases} stored_dataset_aliases = { dsa_m.name: dsa_m for dsa_m in session.scalars( select(DatasetAliasModel).where(DatasetAliasModel.name.in_(all_dataset_alias_names)) ) } - new_dataset_alias_models = { - dataset_alias_model.name: dataset_alias_model - for dataset_alias_model in lineage_dataset_alias_models - if dataset_alias_model.name not in stored_dataset_aliases + new_dataset_aliases = { + dataset_alias.name: dataset_alias + for dataset_alias in all_dataset_aliases + if dataset_alias.name not in stored_dataset_aliases } - session.add_all(new_dataset_alias_models.values()) + session.add_all(new_dataset_aliases.values()) session.flush() - stored_dataset_aliases.update(new_dataset_alias_models) + stored_dataset_aliases.update(new_dataset_aliases) - del new_dataset_alias_models - del lineage_dataset_alias_models + del new_dataset_aliases + del all_dataset_aliases # reconcile dag-schedule-on-dataset and dag-schedule-on-dataset-alias references for dag_id, dag_refs in dag_references.items(): diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index f28573df8617c..29de1b341eeeb 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -129,9 +129,6 @@ def from_public(cls, obj: DatasetAlias) -> DatasetAliasModel: def __repr__(self): return f"{self.__class__.__name__}(name={self.name!r})" - def __hash__(self): - return hash(self.name) - def __eq__(self, other): if isinstance(other, (self.__class__, DatasetAlias)): return self.name == other.name @@ -194,24 +191,21 @@ def from_public(cls, obj: Dataset) -> DatasetModel: def as_public(self) -> Dataset: return Dataset(name=self.name, uri=self.uri, extra=self.extra) - def __init__(self, uri: str, **kwargs): - try: - uri.encode("ascii") - except UnicodeEncodeError: - raise ValueError("URI must be ascii") - parsed = urlsplit(uri) - if parsed.scheme and parsed.scheme.lower() == "airflow": - raise ValueError("Scheme `airflow` is reserved.") - super().__init__(uri=uri, **kwargs) + def __init__(self, *, name: str | None = None, uri: str | None = None, **kwargs): + if name and not name.isascii(): + raise ValueError("name must be ascii") + if uri: + if not uri.isascii(): + raise ValueError("URI must be ascii") + parsed = urlsplit(uri) + if parsed.scheme and parsed.scheme.lower() == "airflow": + raise ValueError("Scheme `airflow` is reserved.") + super().__init__(name=name, uri=uri, **kwargs) def __eq__(self, other): if isinstance(other, (self.__class__, Dataset)): return self.uri == other.uri - else: - return NotImplemented - - def __hash__(self): - return hash(self.uri) + return NotImplemented def __repr__(self): return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})" From ce22c8408dcd88bb27f009a6728fbad15b9082f1 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 28 Aug 2024 14:04:18 +0800 Subject: [PATCH 5/7] Prevent duplicates when inserting new datasets --- airflow/models/dag.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 2a31e8b20d804..9969c2fddd481 100644 --- a/airflow/models/dag.py +++ 
b/airflow/models/dag.py @@ -2796,7 +2796,8 @@ def bulk_write_to_db( # store datasets stored_datasets_by_name: dict[str, DatasetModel] = {} stored_datasets_by_uri: dict[str, DatasetModel] = {} - new_datasets: list[DatasetModel] = [] + new_datasets_by_name: dict[str, DatasetModel] = {} + new_datasets_by_uri: dict[str, DatasetModel] = {} for dataset in all_datasets: stmt = select(DatasetModel) if dataset.name: @@ -2824,12 +2825,17 @@ def bulk_write_to_db( stored_datasets_by_name[stored_dataset.name] = stored_dataset stored_datasets_by_uri[stored_dataset.uri] = stored_dataset else: - new_datasets.append(dataset) - dataset_manager.create_datasets(dataset_models=new_datasets, session=session) - stored_datasets_by_name.update((ds.name, ds) for ds in new_datasets) - stored_datasets_by_uri.update((ds.uri, ds) for ds in new_datasets) - - del new_datasets + if (dataset.name and dataset.name not in new_datasets_by_name) or ( + dataset.uri and dataset.uri not in new_datasets_by_uri + ): + new_datasets_by_name[dataset.name] = dataset + new_datasets_by_uri[dataset.uri] = dataset + dataset_manager.create_datasets(dataset_models=new_datasets_by_name.values(), session=session) + stored_datasets_by_name.update(new_datasets_by_name) + stored_datasets_by_uri.update(new_datasets_by_uri) + + del new_datasets_by_name + del new_datasets_by_uri del all_datasets # store dataset aliases From 870492e0b3d36d23a28c16696b743b66639c272b Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 28 Aug 2024 14:23:38 +0800 Subject: [PATCH 6/7] Fix alias_id type --- airflow/datasets/references.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/datasets/references.py b/airflow/datasets/references.py index 612c5a4b4459a..b2fc9bf3866e6 100644 --- a/airflow/datasets/references.py +++ b/airflow/datasets/references.py @@ -101,7 +101,7 @@ def resolve_dag_schedule_reference( return DagScheduleDatasetReference(dataset_id=dataset_names[ref.name].id, dag_id=dag_id) elif isinstance(ref, DatasetURIReference): return DagScheduleDatasetReference(dataset_id=dataset_uris[ref.uri].id, dag_id=dag_id) - return DagScheduleDatasetAliasReference(alias_id=alias_names[ref.name], dag_id=dag_id) + return DagScheduleDatasetAliasReference(alias_id=alias_names[ref.name].id, dag_id=dag_id) def resolve_task_outlet_reference( From 1722c5017902a609aaf504f1beb9ed6b94157cb4 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 4 Sep 2024 16:28:03 +0800 Subject: [PATCH 7/7] Fix database integrity in tests --- tests/datasets/test_dataset.py | 17 +++++++---------- tests/datasets/test_manager.py | 12 ++++++------ tests/listeners/test_dataset_listener.py | 7 +++---- tests/timetables/test_datasets_timetable.py | 2 +- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 017c874763495..c591ddf4c034b 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -269,9 +269,9 @@ def test_nested_dataset_conditions_with_serialization(status_values, expected_ev @pytest.fixture def create_test_datasets(session): """Fixture to create test datasets and corresponding models.""" - datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)] + datasets = [Dataset(name=f"hello{i}", uri=f"hello{i}") for i in range(1, 3)] for dataset in datasets: - session.add(DatasetModel(uri=dataset.uri)) + session.add(DatasetModel.from_public(dataset)) session.commit() return datasets @@ -337,13 +337,11 @@ def test_dataset_dag_run_queue_processing(session, 
clear_datasets, dag_maker, cr @pytest.mark.usefixtures("clear_datasets") def test_dag_with_complex_dataset_condition(session, dag_maker): # Create Dataset instances - d1 = Dataset(uri="hello1") - d2 = Dataset(uri="hello2") + d1 = Dataset(name="hello1", uri="hello1") + d2 = Dataset(name="hello2", uri="hello2") # Create and add DatasetModel instances to the session - dm1 = DatasetModel(uri=d1.uri) - dm2 = DatasetModel(uri=d2.uri) - session.add_all([dm1, dm2]) + session.add_all([DatasetModel.from_public(d1), DatasetModel.from_public(d2)]) session.commit() # Setup a DAG with complex dataset triggers (DatasetAny with DatasetAll) @@ -539,12 +537,11 @@ def test_normalize_uri_valid_uri(): @pytest.mark.skip_if_database_isolation_mode @pytest.mark.db_test @pytest.mark.usefixtures("clear_datasets") -class Test_DatasetAliasCondition: +class TestDatasetAliasCondition: @pytest.fixture def ds_1(self, session): """Example dataset links to dataset alias resolved_dsa_2.""" - ds_uri = "test_uri" - ds_1 = DatasetModel(id=1, uri=ds_uri) + ds_1 = DatasetModel(id=1, name="test_dataset", uri="test_uri") session.add(ds_1) session.commit() diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py index 1e7b4fda40cee..8eb5d4c97bd27 100644 --- a/tests/datasets/test_manager.py +++ b/tests/datasets/test_manager.py @@ -110,12 +110,12 @@ def test_register_dataset_change_dataset_doesnt_exist(self, mock_task_instance): def test_register_dataset_change(self, session, dag_maker, mock_task_instance): dsem = DatasetManager() - ds = Dataset(uri="test_dataset_uri") + ds = Dataset(name="test_dataset_uri", uri="test_dataset_uri") dag1 = DagModel(dag_id="dag1", is_active=True) dag2 = DagModel(dag_id="dag2", is_active=True) session.add_all([dag1, dag2]) - dsm = DatasetModel(uri="test_dataset_uri") + dsm = DatasetModel.from_public(ds) session.add(dsm) dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)] session.execute(delete(DatasetDagRunQueue)) @@ -130,8 +130,8 @@ def test_register_dataset_change(self, session, dag_maker, mock_task_instance): def test_register_dataset_change_no_downstreams(self, session, mock_task_instance): dsem = DatasetManager() - ds = Dataset(uri="never_consumed") - dsm = DatasetModel(uri="never_consumed") + ds = Dataset(name="never_consumed", uri="never_consumed") + dsm = DatasetModel.from_public(ds) session.add(dsm) session.execute(delete(DatasetDagRunQueue)) session.commit() @@ -148,11 +148,11 @@ def test_register_dataset_change_notifies_dataset_listener(self, session, mock_t dataset_listener.clear() get_listener_manager().add_listener(dataset_listener) - ds = Dataset(uri="test_dataset_uri_2") + ds = Dataset(name="test_dataset_uri_2", uri="test_dataset_uri_2") dag1 = DagModel(dag_id="dag3") session.add_all([dag1]) - dsm = DatasetModel(uri="test_dataset_uri_2") + dsm = DatasetModel.from_public(ds) session.add(dsm) dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag1.dag_id)] session.commit() diff --git a/tests/listeners/test_dataset_listener.py b/tests/listeners/test_dataset_listener.py index b0ac6223e79ea..6f91ac7ac3ee4 100644 --- a/tests/listeners/test_dataset_listener.py +++ b/tests/listeners/test_dataset_listener.py @@ -41,9 +41,8 @@ def clean_listener_manager(): @pytest.mark.db_test @provide_session def test_dataset_listener_on_dataset_changed_gets_calls(create_task_instance_of_operator, session): - dataset_uri = "test_dataset_uri" - ds = Dataset(uri=dataset_uri) - ds_model = DatasetModel(uri=dataset_uri) + ds = 
Dataset(name="test_dataset_name", uri="test_dataset_uri") + ds_model = DatasetModel.from_public(ds) session.add(ds_model) session.flush() @@ -58,4 +57,4 @@ def test_dataset_listener_on_dataset_changed_gets_calls(create_task_instance_of_ ti.run() assert len(dataset_listener.changed) == 1 - assert dataset_listener.changed[0].uri == dataset_uri + assert dataset_listener.changed[0].uri == "test_dataset_uri" diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_datasets_timetable.py index b055f0d34dc90..da8b37ba8975a 100644 --- a/tests/timetables/test_datasets_timetable.py +++ b/tests/timetables/test_datasets_timetable.py @@ -260,7 +260,7 @@ def test_run_ordering_inheritance(dataset_timetable: DatasetOrTimeSchedule) -> N @pytest.mark.db_test def test_summary(session: Session) -> None: - dataset_model = DatasetModel(uri="test_dataset") + dataset_model = DatasetModel(name="test_dataset", uri="test_dataset") dataset_alias_model = DatasetAliasModel(name="test_dataset_alias") session.add_all([dataset_model, dataset_alias_model]) session.commit()