From 02e7925fdf2133462016bd0958fd75bb207e399d Mon Sep 17 00:00:00 2001
From: Tzu-ping Chung
Date: Thu, 15 Feb 2024 19:05:17 +0800
Subject: [PATCH 1/4] Refactor DatasetAll and DatasetAny inheritance

They are moved from airflow.models.dataset to airflow.datasets since the
intention is to use them with Dataset, not DatasetModel. It is more natural
for users to import from the latter module instead.

A new (abstract) base class is added for the two classes, plus the original
Dataset class, to inherit from. This allows us to replace a few isinstance
checks with simple polymorphism and make the logic a bit simpler.
---
 airflow/datasets/__init__.py                | 64 +++++++++++++++++++--
 airflow/models/dag.py                       | 26 +++------
 airflow/models/dataset.py                   | 47 ---------------
 airflow/serialization/serialized_objects.py |  5 +-
 tests/datasets/test_dataset.py              |  4 +-
 5 files changed, 73 insertions(+), 73 deletions(-)

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index eaa25d0a30c35..78b9dd399cd07 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -14,18 +14,35 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations

 import os
-from typing import Any, ClassVar
+from typing import Any, Callable, ClassVar, Iterable, Iterator, Protocol, runtime_checkable
 from urllib.parse import urlsplit

 import attr

+__all__ = ["Dataset", "DatasetAll", "DatasetAny"]
+
+
+@runtime_checkable
+class BaseDatasetEventInput(Protocol):
+    """Protocol for all dataset triggers to use in ``DAG(schedule=...)``.
+
+    :meta private:
+    """
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        raise NotImplementedError
+
+    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+        raise NotImplementedError
+

 @attr.define()
-class Dataset(os.PathLike):
-    """A Dataset is used for marking data dependencies between workflows."""
+class Dataset(os.PathLike[str], BaseDatasetEventInput):
+    """A representation of data dependencies between workflows."""

     uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)])
     extra: dict[str, Any] | None = None
@@ -44,7 +61,7 @@ def _check_uri(self, attr, uri: str):
         if parsed.scheme and parsed.scheme.lower() == "airflow":
             raise ValueError(f"{attr.name!r} scheme `airflow` is reserved")

-    def __fspath__(self):
+    def __fspath__(self) -> str:
         return self.uri

     def __eq__(self, other):
@@ -55,3 +72,42 @@ def __eq__(self, other):

     def __hash__(self):
         return hash(self.uri)
+
+    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+        yield self.uri, self
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        return statuses.get(self.uri, False)
+
+
+class _DatasetBooleanCondition(BaseDatasetEventInput):
+    """Base class for dataset boolean logic."""
+
+    agg_func: Callable[[Iterable], bool]
+
+    def __init__(self, *objects: BaseDatasetEventInput) -> None:
+        self.objects = objects
+
+    def evaluate(self, statuses: dict[str, bool]):
+        return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)
+
+    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+        seen = set()  # We want to keep the first instance.
+        for o in self.objects:
+            for k, v in o.iter_datasets():
+                if k in seen:
+                    continue
+                yield k, v
+                seen.add(k)
+
+
+class DatasetAny(_DatasetBooleanCondition):
+    """Use to combine datasets schedule references in an "or" relationship."""
+
+    agg_func = any
+
+
+class DatasetAll(_DatasetBooleanCondition):
+    """Use to combine datasets schedule references in an "and" relationship."""
+
+    agg_func = all
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 237759010ac22..19bf4285430cf 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -80,6 +80,7 @@
 from airflow import settings, utils
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf as airflow_conf, secrets_backend_list
+from airflow.datasets import BaseDatasetEventInput, Dataset, DatasetAll
 from airflow.datasets.manager import dataset_manager
 from airflow.exceptions import (
     AirflowDagInconsistent,
@@ -98,13 +99,7 @@
 from airflow.models.dagcode import DagCode
 from airflow.models.dagpickle import DagPickle
 from airflow.models.dagrun import RUN_ID_REGEX, DagRun
-from airflow.models.dataset import (
-    DatasetAll,
-    DatasetAny,
-    DatasetBooleanCondition,
-    DatasetDagRunQueue,
-    DatasetModel,
-)
+from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
 from airflow.models.param import DagParam, ParamsDict
 from airflow.models.taskinstance import (
     Context,
@@ -150,7 +145,6 @@
     from sqlalchemy.orm.query import Query
     from sqlalchemy.orm.session import Session

-    from airflow.datasets import Dataset
     from airflow.decorators import TaskDecoratorCollection
     from airflow.models.dagbag import DagBag
     from airflow.models.operator import Operator
@@ -174,7 +168,7 @@
 # but Mypy cannot handle that right now. Track progress of PEP 661 for progress.
# See also: https://discuss.python.org/t/9126/7 ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval] -ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"]] +ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, BaseDatasetEventInput, Collection["Dataset"]] SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None] @@ -586,12 +580,10 @@ def __init__( self.timetable: Timetable self.schedule_interval: ScheduleInterval - self.dataset_triggers: DatasetBooleanCondition | None = None - if isinstance(schedule, (DatasetAll, DatasetAny)): + self.dataset_triggers: BaseDatasetEventInput | None = None + if isinstance(schedule, BaseDatasetEventInput): self.dataset_triggers = schedule - if isinstance(schedule, Collection) and not isinstance(schedule, str): - from airflow.datasets import Dataset - + elif isinstance(schedule, Collection) and not isinstance(schedule, str): if not all(isinstance(x, Dataset) for x in schedule): raise ValueError("All elements in 'schedule' should be datasets") self.dataset_triggers = DatasetAll(*schedule) @@ -3181,7 +3173,7 @@ def bulk_write_to_db( if curr_orm_dag and curr_orm_dag.schedule_dataset_references: curr_orm_dag.schedule_dataset_references = [] else: - for dataset in dag.dataset_triggers.all_datasets().values(): + for _, dataset in dag.dataset_triggers.iter_datasets(): dag_references[dag.dag_id].add(dataset.uri) input_datasets[DatasetModel.from_public(dataset)] = None curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references @@ -3793,14 +3785,14 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[ """ from airflow.models.serialized_dag import SerializedDagModel - def dag_ready(dag_id: str, cond: DatasetBooleanCondition, statuses: dict) -> bool | None: + def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses: dict) -> bool | None: # if dag was serialized before 2.9 and we *just* upgraded, # we may be dealing with old version. In that case, # just wait for the dag to be reserialized. try: return cond.evaluate(statuses) except AttributeError: - log.warning("dag '%s' has old serialization; skipping dag run creation.", dag_id) + log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id) return None # this loads all the DDRQ records.... may need to limit num dags diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index bf28777358786..aa10eb3809756 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -from typing import Callable, Iterable from urllib.parse import urlsplit import sqlalchemy_jsonfield @@ -337,49 +336,3 @@ def __repr__(self) -> str: ]: args.append(f"{attr}={getattr(self, attr)!r}") return f"{self.__class__.__name__}({', '.join(args)})" - - -class DatasetBooleanCondition: - """ - Base class for boolean logic for dataset triggers. 
-
-    :meta private:
-    """
-
-    agg_func: Callable[[Iterable], bool]
-
-    def __init__(self, *objects) -> None:
-        self.objects = objects
-
-    def evaluate(self, statuses: dict[str, bool]) -> bool:
-        return self.agg_func(self.eval_one(x, statuses) for x in self.objects)
-
-    def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) -> bool:
-        if isinstance(obj, Dataset):
-            return statuses.get(obj.uri, False)
-        return obj.evaluate(statuses=statuses)
-
-    def all_datasets(self) -> dict[str, Dataset]:
-        uris = {}
-        for x in self.objects:
-            if isinstance(x, Dataset):
-                if x.uri not in uris:
-                    uris[x.uri] = x
-            else:
-                # keep the first instance
-                for k, v in x.all_datasets().items():
-                    if k not in uris:
-                        uris[k] = v
-        return uris
-
-
-class DatasetAny(DatasetBooleanCondition):
-    """Use to combine datasets schedule references in an "or" relationship."""
-
-    agg_func = any
-
-
-class DatasetAll(DatasetBooleanCondition):
-    """Use to combine datasets schedule references in an "and" relationship."""
-
-    agg_func = all
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 5e6073233e273..552244d73ba7b 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -35,14 +35,13 @@
 from airflow.compat.functools import cache
 from airflow.configuration import conf
-from airflow.datasets import Dataset
+from airflow.datasets import Dataset, DatasetAll, DatasetAny
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError
 from airflow.jobs.job import Job
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG, DagModel, create_timetable
 from airflow.models.dagrun import DagRun
-from airflow.models.dataset import DatasetAll, DatasetAny
 from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import Param, ParamsDict
@@ -788,7 +787,7 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
             return
         if not dag.dataset_triggers:
             return
-        for uri in dag.dataset_triggers.all_datasets().keys():
+        for uri, _ in dag.dataset_triggers.iter_datasets():
             yield DagDependency(
                 source="dataset",
                 target=dag.dag_id,
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index e10264b0e2490..258e542cec264 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -23,8 +23,8 @@
 import pytest
 from sqlalchemy.sql import select

-from airflow.datasets import Dataset
-from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue, DatasetModel
+from airflow.datasets import Dataset, DatasetAll, DatasetAny
+from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.empty import EmptyOperator
 from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG

From 355590afd3174c6839706eecc9b0eb81425a7baa Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Wed, 21 Feb 2024 19:10:51 +0530
Subject: [PATCH 2/4] Fix the import paths and static check errors

Fix the import paths and static check errors
---
 airflow/datasets/__init__.py   |  2 +-
 airflow/timetables/datasets.py | 10 ++++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git
a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 78b9dd399cd07..44f847ecc483b 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -41,7 +41,7 @@ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: @attr.define() -class Dataset(os.PathLike[str], BaseDatasetEventInput): +class Dataset(os.PathLike, BaseDatasetEventInput): """A representation of data dependencies between workflows.""" uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)]) diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py index c755df964ee4d..428c06c8c6b9a 100644 --- a/airflow/timetables/datasets.py +++ b/airflow/timetables/datasets.py @@ -19,8 +19,8 @@ import typing +from airflow.datasets import DatasetAll, _DatasetBooleanCondition from airflow.exceptions import AirflowTimetableInvalid -from airflow.models.dataset import DatasetAll, DatasetBooleanCondition from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule from airflow.utils.types import DagRunType @@ -36,9 +36,11 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule): """Combine time-based scheduling with event-based scheduling.""" - def __init__(self, timetable: Timetable, datasets: Collection[Dataset] | DatasetBooleanCondition) -> None: + def __init__( + self, timetable: Timetable, datasets: Collection[Dataset] | _DatasetBooleanCondition + ) -> None: self.timetable = timetable - if isinstance(datasets, DatasetBooleanCondition): + if isinstance(datasets, _DatasetBooleanCondition): self.datasets = datasets else: self.datasets = DatasetAll(*datasets) @@ -70,7 +72,7 @@ def serialize(self) -> dict[str, typing.Any]: def validate(self) -> None: if isinstance(self.timetable, DatasetTriggeredSchedule): raise AirflowTimetableInvalid("cannot nest dataset timetables") - if not isinstance(self.datasets, DatasetBooleanCondition): + if not isinstance(self.datasets, _DatasetBooleanCondition): raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets") @property From 47325a8f226476487eb7e732f967d043a3999d6e Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:34:55 +0530 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Wei Lee --- airflow/datasets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 44f847ecc483b..1d08d7d6d3621 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -88,7 +88,7 @@ class _DatasetBooleanCondition(BaseDatasetEventInput): def __init__(self, *objects: BaseDatasetEventInput) -> None: self.objects = objects - def evaluate(self, statuses: dict[str, bool]): + def evaluate(self, statuses: dict[str, bool]) -> bool: return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects) def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: From e8dac2ce5706bace59c11b25ad22a82710e7f86f Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 22 Feb 2024 16:19:34 +0800 Subject: [PATCH 4/4] Use BaseDatasetEventInput in timetable --- airflow/timetables/datasets.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py index 428c06c8c6b9a..dcc0652929285 100644 --- a/airflow/timetables/datasets.py +++ b/airflow/timetables/datasets.py @@ -19,7 +19,7 @@ import typing -from airflow.datasets import DatasetAll, 
_DatasetBooleanCondition +from airflow.datasets import BaseDatasetEventInput, DatasetAll from airflow.exceptions import AirflowTimetableInvalid from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule from airflow.utils.types import DagRunType @@ -37,10 +37,13 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule): """Combine time-based scheduling with event-based scheduling.""" def __init__( - self, timetable: Timetable, datasets: Collection[Dataset] | _DatasetBooleanCondition + self, + *, + timetable: Timetable, + datasets: Collection[Dataset] | BaseDatasetEventInput, ) -> None: self.timetable = timetable - if isinstance(datasets, _DatasetBooleanCondition): + if isinstance(datasets, BaseDatasetEventInput): self.datasets = datasets else: self.datasets = DatasetAll(*datasets) @@ -72,7 +75,7 @@ def serialize(self) -> dict[str, typing.Any]: def validate(self) -> None: if isinstance(self.timetable, DatasetTriggeredSchedule): raise AirflowTimetableInvalid("cannot nest dataset timetables") - if not isinstance(self.datasets, _DatasetBooleanCondition): + if not isinstance(self.datasets, BaseDatasetEventInput): raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets") @property
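
Usage sketch (illustrative only, not part of the patches above; the DAG id, start
date, and dataset URIs are made up): after this series, DatasetAll and DatasetAny
are imported from airflow.datasets, and plain Dataset objects, the two condition
classes, and any nesting of them all satisfy BaseDatasetEventInput, so they can be
passed directly to DAG(schedule=...) or handed to DatasetOrTimeSchedule.

    from datetime import datetime

    from airflow.datasets import Dataset, DatasetAll, DatasetAny
    from airflow.models.dag import DAG

    # Hypothetical datasets, used only for illustration.
    raw = Dataset("s3://bucket/raw/events.json")
    clean = Dataset("s3://bucket/clean/events.parquet")
    audit = Dataset("s3://bucket/audit/log.csv")

    # Conditions nest freely: trigger when both raw and clean have events,
    # or when audit has one.
    cond = DatasetAny(DatasetAll(raw, clean), audit)

    with DAG(dag_id="example_dataset_condition", start_date=datetime(2024, 1, 1), schedule=cond):
        ...

    # The protocol added in patch 1: every condition (and Dataset itself) exposes
    # evaluate() and iter_datasets(), which dags_needing_dagruns and
    # detect_dag_dependencies now call polymorphically.
    assert cond.evaluate({raw.uri: True, clean.uri: True})
    assert [uri for uri, _ in cond.iter_datasets()] == [raw.uri, clean.uri, audit.uri]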