diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index eaa25d0a30c35..1d08d7d6d3621 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -14,18 +14,35 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations

 import os
-from typing import Any, ClassVar
+from typing import Any, Callable, ClassVar, Iterable, Iterator, Protocol, runtime_checkable
 from urllib.parse import urlsplit

 import attr

+__all__ = ["Dataset", "DatasetAll", "DatasetAny"]
+
+
+@runtime_checkable
+class BaseDatasetEventInput(Protocol):
+    """Protocol for all dataset triggers to use in ``DAG(schedule=...)``.
+
+    :meta private:
+    """
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        raise NotImplementedError
+
+    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+        raise NotImplementedError
+

 @attr.define()
-class Dataset(os.PathLike):
-    """A Dataset is used for marking data dependencies between workflows."""
+class Dataset(os.PathLike, BaseDatasetEventInput):
+    """A representation of data dependencies between workflows."""

     uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)])
     extra: dict[str, Any] | None = None
@@ -44,7 +61,7 @@ def _check_uri(self, attr, uri: str):
         if parsed.scheme and parsed.scheme.lower() == "airflow":
             raise ValueError(f"{attr.name!r} scheme `airflow` is reserved")

-    def __fspath__(self):
+    def __fspath__(self) -> str:
         return self.uri

     def __eq__(self, other):
@@ -55,3 +72,42 @@ def __eq__(self, other):

     def __hash__(self):
         return hash(self.uri)
+
+    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+        yield self.uri, self
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        return statuses.get(self.uri, False)
+
+
+class _DatasetBooleanCondition(BaseDatasetEventInput):
+    """Base class for dataset boolean logic."""
+
+    agg_func: Callable[[Iterable], bool]
+
+    def __init__(self, *objects: BaseDatasetEventInput) -> None:
+        self.objects = objects
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)
+
+    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+        seen = set()  # We want to keep the first instance.
+        for o in self.objects:
+            for k, v in o.iter_datasets():
+                if k in seen:
+                    continue
+                yield k, v
+                seen.add(k)
+
+
+class DatasetAny(_DatasetBooleanCondition):
+    """Use to combine datasets schedule references in an "or" relationship."""
+
+    agg_func = any
+
+
+class DatasetAll(_DatasetBooleanCondition):
+    """Use to combine datasets schedule references in an "and" relationship."""
+
+    agg_func = all
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 237759010ac22..19bf4285430cf 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -80,6 +80,7 @@
 from airflow import settings, utils
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf as airflow_conf, secrets_backend_list
+from airflow.datasets import BaseDatasetEventInput, Dataset, DatasetAll
 from airflow.datasets.manager import dataset_manager
 from airflow.exceptions import (
     AirflowDagInconsistent,
@@ -98,13 +99,7 @@
 from airflow.models.dagcode import DagCode
 from airflow.models.dagpickle import DagPickle
 from airflow.models.dagrun import RUN_ID_REGEX, DagRun
-from airflow.models.dataset import (
-    DatasetAll,
-    DatasetAny,
-    DatasetBooleanCondition,
-    DatasetDagRunQueue,
-    DatasetModel,
-)
+from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
 from airflow.models.param import DagParam, ParamsDict
 from airflow.models.taskinstance import (
     Context,
@@ -150,7 +145,6 @@
     from sqlalchemy.orm.query import Query
     from sqlalchemy.orm.session import Session

-    from airflow.datasets import Dataset
     from airflow.decorators import TaskDecoratorCollection
     from airflow.models.dagbag import DagBag
     from airflow.models.operator import Operator
@@ -174,7 +168,7 @@
 # but Mypy cannot handle that right now. Track progress of PEP 661 for progress.
 # See also: https://discuss.python.org/t/9126/7
 ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
-ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"]]
+ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, BaseDatasetEventInput, Collection["Dataset"]]

 SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None]

@@ -586,12 +580,10 @@ def __init__(
         self.timetable: Timetable
         self.schedule_interval: ScheduleInterval

-        self.dataset_triggers: DatasetBooleanCondition | None = None
-        if isinstance(schedule, (DatasetAll, DatasetAny)):
+        self.dataset_triggers: BaseDatasetEventInput | None = None
+        if isinstance(schedule, BaseDatasetEventInput):
             self.dataset_triggers = schedule
-        if isinstance(schedule, Collection) and not isinstance(schedule, str):
-            from airflow.datasets import Dataset
-
+        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
             if not all(isinstance(x, Dataset) for x in schedule):
                 raise ValueError("All elements in 'schedule' should be datasets")
             self.dataset_triggers = DatasetAll(*schedule)
@@ -3181,7 +3173,7 @@ def bulk_write_to_db(
                 if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
                     curr_orm_dag.schedule_dataset_references = []
             else:
-                for dataset in dag.dataset_triggers.all_datasets().values():
+                for _, dataset in dag.dataset_triggers.iter_datasets():
                     dag_references[dag.dag_id].add(dataset.uri)
                     input_datasets[DatasetModel.from_public(dataset)] = None
             curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
@@ -3793,14 +3785,14 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
         """
         from airflow.models.serialized_dag import SerializedDagModel

-        def dag_ready(dag_id: str, cond: DatasetBooleanCondition, statuses: dict) -> bool | None:
+        def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses: dict) -> bool | None:
             # if dag was serialized before 2.9 and we *just* upgraded,
             # we may be dealing with old version. In that case,
             # just wait for the dag to be reserialized.
             try:
                 return cond.evaluate(statuses)
             except AttributeError:
-                log.warning("dag '%s' has old serialization; skipping dag run creation.", dag_id)
+                log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
                 return None

         # this loads all the DDRQ records.... may need to limit num dags
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
index bf28777358786..aa10eb3809756 100644
--- a/airflow/models/dataset.py
+++ b/airflow/models/dataset.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations

-from typing import Callable, Iterable
 from urllib.parse import urlsplit

 import sqlalchemy_jsonfield
@@ -337,49 +336,3 @@ def __repr__(self) -> str:
         ]:
             args.append(f"{attr}={getattr(self, attr)!r}")
         return f"{self.__class__.__name__}({', '.join(args)})"
-
-
-class DatasetBooleanCondition:
-    """
-    Base class for boolean logic for dataset triggers.
-
-    :meta private:
-    """
-
-    agg_func: Callable[[Iterable], bool]
-
-    def __init__(self, *objects) -> None:
-        self.objects = objects
-
-    def evaluate(self, statuses: dict[str, bool]) -> bool:
-        return self.agg_func(self.eval_one(x, statuses) for x in self.objects)
-
-    def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) -> bool:
-        if isinstance(obj, Dataset):
-            return statuses.get(obj.uri, False)
-        return obj.evaluate(statuses=statuses)
-
-    def all_datasets(self) -> dict[str, Dataset]:
-        uris = {}
-        for x in self.objects:
-            if isinstance(x, Dataset):
-                if x.uri not in uris:
-                    uris[x.uri] = x
-            else:
-                # keep the first instance
-                for k, v in x.all_datasets().items():
-                    if k not in uris:
-                        uris[k] = v
-        return uris
-
-
-class DatasetAny(DatasetBooleanCondition):
-    """Use to combine datasets schedule references in an "or" relationship."""
-
-    agg_func = any
-
-
-class DatasetAll(DatasetBooleanCondition):
-    """Use to combine datasets schedule references in an "and" relationship."""
-
-    agg_func = all
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 5e6073233e273..552244d73ba7b 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -35,14 +35,13 @@
 from airflow.compat.functools import cache
 from airflow.configuration import conf
-from airflow.datasets import Dataset
+from airflow.datasets import Dataset, DatasetAll, DatasetAny
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError
 from airflow.jobs.job import Job
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG, DagModel, create_timetable
 from airflow.models.dagrun import DagRun
-from airflow.models.dataset import DatasetAll, DatasetAny
 from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import Param, ParamsDict
@@ -788,7 +787,7 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
             return
         if not dag.dataset_triggers:
             return
-        for uri in dag.dataset_triggers.all_datasets().keys():
+        for uri, _ in dag.dataset_triggers.iter_datasets():
             yield DagDependency(
                 source="dataset",
                 target=dag.dag_id,
diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py
index c755df964ee4d..dcc0652929285 100644
--- a/airflow/timetables/datasets.py
+++ b/airflow/timetables/datasets.py
@@ -19,8 +19,8 @@

 import typing

+from airflow.datasets import BaseDatasetEventInput, DatasetAll
 from airflow.exceptions import AirflowTimetableInvalid
-from airflow.models.dataset import DatasetAll, DatasetBooleanCondition
 from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule
 from airflow.utils.types import DagRunType

@@ -36,9 +36,14 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
     """Combine time-based scheduling with event-based scheduling."""

-    def __init__(self, timetable: Timetable, datasets: Collection[Dataset] | DatasetBooleanCondition) -> None:
+    def __init__(
+        self,
+        *,
+        timetable: Timetable,
+        datasets: Collection[Dataset] | BaseDatasetEventInput,
+    ) -> None:
         self.timetable = timetable
-        if isinstance(datasets, DatasetBooleanCondition):
+        if isinstance(datasets, BaseDatasetEventInput):
             self.datasets = datasets
         else:
             self.datasets = DatasetAll(*datasets)

@@ -70,7 +75,7 @@ def serialize(self) -> dict[str, typing.Any]:
     def validate(self) -> None:
         if isinstance(self.timetable, DatasetTriggeredSchedule):
             raise AirflowTimetableInvalid("cannot nest dataset timetables")
-        if not isinstance(self.datasets, DatasetBooleanCondition):
+        if not isinstance(self.datasets, BaseDatasetEventInput):
             raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets")

     @property
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index e10264b0e2490..258e542cec264 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -23,8 +23,8 @@
 import pytest
 from sqlalchemy.sql import select

-from airflow.datasets import Dataset
-from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue, DatasetModel
+from airflow.datasets import Dataset, DatasetAll, DatasetAny
+from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.empty import EmptyOperator
 from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG
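
Below is a minimal usage sketch of the classes this change relocates into airflow.datasets. It assumes only the behaviour visible in the hunks above (DatasetAny/DatasetAll nesting, evaluate() folding per-URI statuses through each class's agg_func, iter_datasets() yielding each unique URI once); the DAG id, dataset URIs, and task are illustrative, not taken from the patch.

import datetime

from airflow.datasets import Dataset, DatasetAll, DatasetAny
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator

# Hypothetical URIs, used only to illustrate the new schedule argument.
raw = Dataset("s3://example-bucket/raw")
clean = Dataset("s3://example-bucket/clean")
audit = Dataset("s3://example-bucket/audit")

# Conditions nest: trigger when `raw` is updated, or when both `clean`
# and `audit` have been updated since the last run.
cond = DatasetAny(raw, DatasetAll(clean, audit))

with DAG(
    dag_id="example_dataset_condition",
    start_date=datetime.datetime(2024, 1, 1),
    schedule=cond,
) as dag:
    EmptyOperator(task_id="consume")

# evaluate() aggregates per-URI statuses with any/all;
# iter_datasets() yields every unique (uri, Dataset) pair exactly once.
assert [uri for uri, _ in cond.iter_datasets()] == [raw.uri, clean.uri, audit.uri]
assert cond.evaluate({raw.uri: False, clean.uri: True, audit.uri: True}) is True

Before this change, the equivalent condition classes lived in airflow.models.dataset and exposed all_datasets() rather than the iter_datasets() generator used here.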