diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py index 05c73a512f071..2525250ed82ee 100644 --- a/airflow-core/src/airflow/api/common/mark_tasks.py +++ b/airflow-core/src/airflow/api/common/mark_tasks.py @@ -76,10 +76,13 @@ def set_state( if not tasks: return [] - task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks} + task_dags = { + (dag.dag_id if dag else None): dag + for dag in (task[0].dag if isinstance(task, tuple) else task.dag for task in tasks) + } if len(task_dags) > 1: raise ValueError(f"Received tasks from multiple DAGs: {task_dags}") - dag = next(iter(task_dags)) + dag = next(iter(task_dags.values())) if dag is None: raise ValueError("Received tasks with no DAG") if not run_id: diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index eb808fb6cbc4b..d3e43afbd916e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -304,7 +304,7 @@ def _get_upstream_map_indexes( if (upstream_mapped_group := upstream_task.get_closest_mapped_task_group()) is None: # regular tasks or non-mapped task groups map_indexes = None - elif task.get_closest_mapped_task_group() == upstream_mapped_group: + elif task.get_closest_mapped_task_group() is upstream_mapped_group: # tasks in the same mapped task group hierarchy map_indexes = ti.map_index else: diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index ef9a653fc2f24..586edad563f62 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -38,7 +38,8 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.sdk.definitions.dag import DAG, _run_task from airflow.sdk.definitions.param import ParamsDict -from airflow.serialization.serialized_objects import SerializedDAG +from airflow.serialization.definitions.dag import SerializedDAG +from airflow.serialization.serialized_objects import DagSerialization from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS from airflow.utils import cli as cli_utils @@ -384,7 +385,7 @@ def task_test(args, dag: DAG | None = None) -> None: if dag: sdk_dag = dag - scheduler_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + scheduler_dag = DagSerialization.from_dict(DagSerialization.to_dict(dag)) else: sdk_dag = get_bagged_dag(args.bundle_name, args.dag_id) scheduler_dag = get_db_dag(args.bundle_name, args.dag_id) @@ -429,11 +430,14 @@ def task_test(args, dag: DAG | None = None) -> None: @providers_configuration_loaded def task_render(args, dag: DAG | None = None) -> None: """Render and displays templated fields for a given task.""" - if not dag: - dag = get_bagged_dag(args.bundle_name, args.dag_id) - serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + if dag: + sdk_dag = dag + scheduler_dag = DagSerialization.from_dict(DagSerialization.to_dict(dag)) + else: + sdk_dag = get_bagged_dag(args.bundle_name, args.dag_id) + scheduler_dag = get_db_dag(args.bundle_name, args.dag_id) ti, _ = _get_ti( - serialized_dag.get_task(task_id=args.task_id), + scheduler_dag.get_task(task_id=args.task_id), args.map_index, 
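# Illustrative sketch, not part of the diff: the mark_tasks hunk above now dedupes the
# tasks' DAGs by dag_id rather than putting DAG objects in a set, and the execution-API
# hunk compares mapped task groups with `is` instead of `==`. Both are consistent with
# the scheduler-side classes below being attrs classes declared with eq=False, where
# comparison falls back to object identity. Minimal sketch of the identity check,
# assuming `task` and `upstream_task` are scheduler-side operators as in
# _get_upstream_map_indexes():
group = task.get_closest_mapped_task_group()
upstream_group = upstream_task.get_closest_mapped_task_group()
same_mapped_group_hierarchy = group is not None and group is upstream_group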
logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="memory", @@ -441,7 +445,7 @@ def task_render(args, dag: DAG | None = None) -> None: with create_session() as session: context = ti.get_template_context(session=session) - task = dag.get_task(args.task_id) + task = sdk_dag.get_task(args.task_id) # TODO (GH-52141): After sdk separation, ti.get_template_context() would # contain serialized operators, but we need the real operators for # rendering. This does not make sense and eventually we should rewrite diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index ae984a0eecfd2..41181f7bca848 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -66,7 +66,7 @@ ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, _send_error_email_notification -from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG +from airflow.serialization.serialized_objects import DagSerialization, LazyDeserializedDAG from airflow.utils.file import iter_airflow_imports from airflow.utils.state import TaskInstanceState @@ -239,7 +239,7 @@ def _serialize_dags( serialized_dags = [] for dag in bag.dags.values(): try: - data = SerializedDAG.to_dict(dag) + data = DagSerialization.to_dict(dag) serialized_dags.append(LazyDeserializedDAG(data=data, last_loaded=dag.last_loaded)) except Exception: log.exception("Failed to serialize DAG: %s", dag.fileloc) diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py index 3809a42d08d6d..664b874176fa9 100644 --- a/airflow-core/src/airflow/models/serialized_dag.py +++ b/airflow-core/src/airflow/models/serialized_dag.py @@ -45,7 +45,7 @@ from airflow.models.dagrun import DagRun from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.definitions.assets import SerializedAssetUniqueKey as UKey -from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG +from airflow.serialization.serialized_objects import DagSerialization from airflow.settings import COMPRESS_SERIALIZED_DAGS, json from airflow.utils.hashlib_wrapper import md5 from airflow.utils.session import NEW_SESSION, provide_session @@ -56,6 +56,9 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql.elements import ColumnElement + from airflow.serialization.definitions.dag import SerializedDAG + from airflow.serialization.serialized_objects import LazyDeserializedDAG + log = logging.getLogger(__name__) @@ -568,14 +571,14 @@ def data(self) -> dict | None: @property def dag(self) -> SerializedDAG: """The DAG deserialized from the ``data`` column.""" - SerializedDAG._load_operator_extra_links = self.load_op_links + DagSerialization._load_operator_extra_links = self.load_op_links if isinstance(self.data, dict): data = self.data elif isinstance(self.data, str): data = json.loads(self.data) else: raise ValueError("invalid or missing serialized DAG data") - return SerializedDAG.from_dict(data) + return DagSerialization.from_dict(data) @classmethod @provide_session diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index eb3112736644d..0814f438688ac 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py 
@@ -115,8 +115,9 @@ from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun from airflow.models.mappedoperator import MappedOperator + from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.definitions.taskgroup import SerializedTaskGroup - from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG + from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.context import Context Operator: TypeAlias = MappedOperator | SerializedBaseOperator @@ -1480,7 +1481,8 @@ def run( """Run TaskInstance (only kept for tests).""" # This method is only used in ti.run and dag.test and task.test. # So doing the s10n/de-s10n dance to operator on Serialized task for the scheduler dep check part. - from airflow.serialization.serialized_objects import SerializedDAG + from airflow.serialization.definitions.dag import SerializedDAG + from airflow.serialization.serialized_objects import DagSerialization original_task = self.task if TYPE_CHECKING: @@ -1489,7 +1491,7 @@ def run( # We don't set up all tests well... if not isinstance(original_task.dag, SerializedDAG): - serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(original_task.dag)) + serialized_dag = DagSerialization.from_dict(DagSerialization.to_dict(original_task.dag)) self.task = serialized_dag.get_task(original_task.task_id) res = self.check_and_change_state_before_execution( diff --git a/airflow-core/src/airflow/serialization/definitions/dag.py b/airflow-core/src/airflow/serialization/definitions/dag.py new file mode 100644 index 0000000000000..f6556dcfd1c0e --- /dev/null +++ b/airflow-core/src/airflow/serialization/definitions/dag.py @@ -0,0 +1,1112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
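# Illustrative sketch, not part of the diff: the recurring pattern in this change is
# that call sites which previously round-tripped through SerializedDAG's own
# to_dict()/from_dict() now go through the DagSerialization helper, which rebuilds the
# attrs-based SerializedDAG defined in airflow/serialization/definitions/dag.py below.
# Hedged round-trip sketch, assuming `sdk_dag` is an airflow.sdk DAG instance:
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization

def to_scheduler_dag(sdk_dag) -> SerializedDAG:
    # Serialize the authored DAG to its JSON-able form, then deserialize the
    # scheduler-side representation, as task_test() and TaskInstance.run() do above.
    return DagSerialization.from_dict(DagSerialization.to_dict(sdk_dag))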
+ +from __future__ import annotations + +import copy +import functools +import itertools +import operator +import re +import weakref +from typing import TYPE_CHECKING, cast, overload + +import attrs +import structlog +from sqlalchemy import func, or_, select, tuple_ + +from airflow._shared.timezones.timezone import coerce_datetime +from airflow.configuration import conf as airflow_conf +from airflow.exceptions import AirflowException, TaskNotFound +from airflow.models.dag import DagModel +from airflow.models.dag_version import DagVersion +from airflow.models.dagrun import DagRun +from airflow.models.deadline import Deadline +from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.models.tasklog import LogTemplate +from airflow.observability.stats import Stats +from airflow.sdk.definitions.deadline import DeadlineReference +from airflow.serialization.definitions.param import SerializedParamsDict +from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.types import DagRunType + +if TYPE_CHECKING: + import datetime + from collections.abc import Collection, Iterable, Sequence + from typing import Any, Literal + + from pendulum.tz.timezone import FixedTimezone, Timezone + from pydantic import NonNegativeInt + from sqlalchemy.orm import Session + from typing_extensions import TypeIs + + from airflow.models.taskinstance import TaskInstance + from airflow.sdk import DAG + from airflow.sdk.definitions.deadline import DeadlineAlert + from airflow.sdk.definitions.edges import EdgeInfoType + from airflow.serialization.definitions.taskgroup import SerializedTaskGroup + from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedOperator + from airflow.timetables.base import Timetable + from airflow.utils.types import DagRunTriggeredByType + +log = structlog.get_logger(__name__) + + +@attrs.define(eq=False, hash=False, slots=False) +class SerializedDAG: + """ + Serialized representation of a ``DAG`` instance. + + A stringified DAG can only be used in the scope of scheduler and webserver. + Fields that are not serializable, such as functions and customer defined + classes, are casted to strings. + """ + + dag_id: str + dag_display_name: str = attrs.field(default=attrs.Factory(operator.attrgetter("dag_id"), takes_self=True)) + + # Default values of fields below should match schema default. 
+ access_control: dict[str, dict[str, Collection[str]]] | None = None + catchup: bool = False + dagrun_timeout: datetime.timedelta | None = None + deadline: list[DeadlineAlert] | DeadlineAlert | None = None + default_args: dict[str, Any] = attrs.field(factory=dict) + description: str | None = None + disable_bundle_versioning: bool = False + doc_md: str | None = None + edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(factory=dict) + end_date: datetime.datetime | None = None + fail_fast: bool = False + has_on_failure_callback: bool = False + has_on_success_callback: bool = False + is_paused_upon_creation: bool | None = None + max_active_runs: int = 16 + max_active_tasks: int = 16 + max_consecutive_failed_dag_runs: int = 0 + owner_links: dict[str, str] = attrs.field(factory=dict) + params: SerializedParamsDict = attrs.field(factory=SerializedParamsDict) + partial: bool = False + render_template_as_native_obj: bool = False + start_date: datetime.datetime | None = None + tags: set[str] = attrs.field(factory=set) + template_searchpath: tuple[str, ...] | None = None + + # These are set dynamically during deserialization. + task_dict: dict[str, SerializedOperator] = attrs.field(init=False) + task_group: SerializedTaskGroup = attrs.field(init=False) + timetable: Timetable = attrs.field(init=False) + timezone: FixedTimezone | Timezone = attrs.field(init=False) + + # Only on serialized dag. + last_loaded: datetime.datetime = attrs.field(init=False) + # Determine the relative fileloc based only on the serialize dag. + _processor_dags_folder: str = attrs.field(init=False) + + def __init__(self, *, dag_id: str) -> None: + self.__attrs_init__(dag_id=dag_id, dag_display_name=dag_id) # type: ignore[attr-defined] + + def __repr__(self) -> str: + return f"" + + @classmethod + def get_serialized_fields(cls) -> frozenset[str]: + return frozenset( + { + "access_control", + "catchup", + "dag_display_name", + "dag_id", + "dagrun_timeout", + "deadline", + "default_args", + "description", + "disable_bundle_versioning", + "doc_md", + "edge_info", + "end_date", + "fail_fast", + "fileloc", + "is_paused_upon_creation", + "max_active_runs", + "max_active_tasks", + "max_consecutive_failed_dag_runs", + "owner_links", + "relative_fileloc", + "render_template_as_native_obj", + "start_date", + "tags", + "task_group", + "timetable", + "timezone", + } + ) + + @classmethod + @provide_session + def bulk_write_to_db( + cls, + bundle_name: str, + bundle_version: str | None, + dags: Collection[DAG | LazyDeserializedDAG], + parse_duration: float | None = None, + session: Session = NEW_SESSION, + ) -> None: + """ + Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB. 
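# Illustrative sketch, not part of the diff: the new SerializedDAG is a plain attrs
# class. Only dag_id is required by the custom __init__, the other serialized fields
# start at their schema defaults, and task_dict/task_group/timetable/timezone are
# populated later by the deserializer. A minimal sketch of what that implies, with a
# made-up dag_id:
from airflow.serialization.definitions.dag import SerializedDAG

dag = SerializedDAG(dag_id="example_dag")        # hypothetical dag_id
assert dag.dag_display_name == "example_dag"
assert dag.catchup is False and dag.max_active_runs == 16
assert "timetable" in SerializedDAG.get_serialized_fields()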
+ + :param dags: the DAG objects to save to the DB + :return: None + """ + if not dags: + return + + from airflow.dag_processing.collection import AssetModelOperation, DagModelOperation + from airflow.serialization.serialized_objects import LazyDeserializedDAG + + log.info("Bulk-writing dags to db", count=len(dags)) + dag_op = DagModelOperation( + bundle_name=bundle_name, + bundle_version=bundle_version, + dags={d.dag_id: LazyDeserializedDAG.from_dag(d) for d in dags}, + ) + + orm_dags = dag_op.add_dags(session=session) + dag_op.update_dags(orm_dags, parse_duration, session=session) + + asset_op = AssetModelOperation.collect(dag_op.dags) + + orm_assets = asset_op.sync_assets(session=session) + orm_asset_aliases = asset_op.sync_asset_aliases(session=session) + session.flush() # This populates id so we can create fks in later calls. + + orm_dags = dag_op.find_orm_dags(session=session) # Refetch so relationship is up to date. + asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) + asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) + asset_op.add_dag_asset_name_uri_references(session=session) + asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) + asset_op.activate_assets_if_possible(orm_assets.values(), session=session) + session.flush() # Activation is needed when we add trigger references. + + asset_op.add_asset_trigger_references(orm_assets, session=session) + dag_op.update_dag_asset_expression(orm_dags=orm_dags, orm_assets=orm_assets) + session.flush() + + @property + def tasks(self) -> Sequence[SerializedOperator]: + return list(self.task_dict.values()) + + @property + def task_ids(self) -> list[str]: + return list(self.task_dict) + + @property + def roots(self) -> list[SerializedOperator]: + return [task for task in self.tasks if not task.upstream_list] + + @property + def owner(self) -> str: + return ", ".join({t.owner for t in self.tasks}) + + def has_task(self, task_id: str) -> bool: + return task_id in self.task_dict + + def get_task(self, task_id: str) -> SerializedOperator: + if task_id in self.task_dict: + return self.task_dict[task_id] + raise TaskNotFound(f"Task {task_id} not found") + + @property + def task_group_dict(self): + return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None} + + def partial_subset( + self, + task_ids: str | Iterable[str], + include_downstream: bool = False, + include_upstream: bool = True, + include_direct_upstream: bool = False, + exclude_original: bool = False, + ): + from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator + + def is_task(obj) -> TypeIs[SerializedOperator]: + return isinstance(obj, (SerializedMappedOperator, SerializedBaseOperator)) + + # deep-copying self.task_dict and self.task_group takes a long time, and we don't want all + # the tasks anyway, so we copy the tasks manually later + memo = {id(self.task_dict): None, id(self.task_group): None} + dag = copy.deepcopy(self, memo) + + if isinstance(task_ids, str): + matched_tasks = [t for t in self.tasks if task_ids in t.task_id] + else: + matched_tasks = [t for t in self.tasks if t.task_id in task_ids] + + also_include_ids: set[str] = set() + for t in matched_tasks: + if include_downstream: + for rel in t.get_flat_relatives(upstream=False): + also_include_ids.add(rel.task_id) + if rel not in matched_tasks: # if it's in there, we're already processing it + # need to include 
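# Illustrative sketch, not part of the diff: bulk_write_to_db() is a classmethod, so
# the DAG-parsing path can sync a whole bundle of parsed DAGs in one call. Hedged
# usage sketch; the bundle name and the `parsed_dags` collection are assumptions:
from airflow.serialization.definitions.dag import SerializedDAG

SerializedDAG.bulk_write_to_db(
    bundle_name="example-bundle",   # hypothetical bundle
    bundle_version=None,
    dags=parsed_dags,               # DAG or LazyDeserializedDAG objects from parsing
)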
setups and teardowns for tasks that are in multiple + # non-collinear setup/teardown paths + if not rel.is_setup and not rel.is_teardown: + also_include_ids.update( + x.task_id for x in rel.get_upstreams_only_setups_and_teardowns() + ) + if include_upstream: + also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups()) + else: + if not t.is_setup and not t.is_teardown: + also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns()) + if t.is_setup and not include_downstream: + also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown) + + also_include: list[SerializedOperator] = [self.task_dict[x] for x in also_include_ids] + direct_upstreams: list[SerializedOperator] = [] + if include_direct_upstream: + for t in itertools.chain(matched_tasks, also_include): + # TODO (GH-52141): This should return scheduler types, but currently we reuse SDK DAGNode. + upstream = (u for u in cast("Iterable[SerializedOperator]", t.upstream_list) if is_task(u)) + direct_upstreams.extend(upstream) + + # Make sure to not recursively deepcopy the dag or task_group while copying the task. + # task_group is reset later + def _deepcopy_task(t) -> SerializedOperator: + memo.setdefault(id(t.task_group), None) + return copy.deepcopy(t, memo) + + # Compiling the unique list of tasks that made the cut + if exclude_original: + matched_tasks = [] + dag.task_dict = { + t.task_id: _deepcopy_task(t) + for t in itertools.chain(matched_tasks, also_include, direct_upstreams) + } + + def filter_task_group(group, parent_group): + """Exclude tasks not included in the partial dag from the given TaskGroup.""" + # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy + # and then manually deep copy the instances. (memo argument to deepcopy only works for instances + # of classes, not "native" properties of an instance) + copied = copy.copy(group) + + memo[id(group.children)] = {} + if parent_group: + memo[id(group.parent_group)] = parent_group + for attr in type(group).__slots__: + value = getattr(group, attr) + value = copy.deepcopy(value, memo) + object.__setattr__(copied, attr, value) + + proxy = weakref.proxy(copied) + + for child in group.children.values(): + if is_task(child): + if child.task_id in dag.task_dict: + task = copied.children[child.task_id] = dag.task_dict[child.task_id] + task.task_group = proxy + else: + filtered_child = filter_task_group(child, proxy) + + # Only include this child TaskGroup if it is non-empty. + if filtered_child.children: + copied.children[child.group_id] = filtered_child + + return copied + + object.__setattr__(dag, "task_group", filter_task_group(self.task_group, None)) + + # Removing upstream/downstream references to tasks and TaskGroups that did not make + # the cut. 
+ groups = dag.task_group.get_task_group_dict() + for g in groups.values(): + g.upstream_group_ids.intersection_update(groups) + g.downstream_group_ids.intersection_update(groups) + g.upstream_task_ids.intersection_update(dag.task_dict) + g.downstream_task_ids.intersection_update(dag.task_dict) + + for t in dag.tasks: + # Removing upstream/downstream references to tasks that did not + # make the cut + t.upstream_task_ids.intersection_update(dag.task_dict) + t.downstream_task_ids.intersection_update(dag.task_dict) + + dag.partial = len(dag.tasks) < len(self.tasks) + + return dag + + @functools.cached_property + def _time_restriction(self) -> TimeRestriction: + start_dates = [t.start_date for t in self.tasks if t.start_date] + if self.start_date is not None: + start_dates.append(self.start_date) + earliest = None + if start_dates: + earliest = coerce_datetime(min(start_dates)) + latest = coerce_datetime(self.end_date) + end_dates = [t.end_date for t in self.tasks if t.end_date] + if len(end_dates) == len(self.tasks): # not exists null end_date + if self.end_date is not None: + end_dates.append(self.end_date) + if end_dates: + latest = coerce_datetime(max(end_dates)) + return TimeRestriction(earliest, latest, self.catchup) + + def next_dagrun_info( + self, + last_automated_dagrun: None | DataInterval, + *, + restricted: bool = True, + ) -> DagRunInfo | None: + """ + Get information about the next DagRun of this dag after ``date_last_automated_dagrun``. + + This calculates what time interval the next DagRun should operate on + (its logical date) and when it can be scheduled, according to the + dag's timetable, start_date, end_date, etc. This doesn't check max + active run or any other "max_active_tasks" type limits, but only + performs calculations based on the various date and interval fields of + this dag and its tasks. + + :param last_automated_dagrun: The ``max(logical_date)`` of + existing "automated" DagRuns for this dag (scheduled or backfill, + but not manual). + :param restricted: If set to *False* (default is *True*), ignore + ``start_date``, ``end_date``, and ``catchup`` specified on the DAG + or tasks. + :return: DagRunInfo of the next dagrun, or None if a dagrun is not + going to be scheduled. + """ + if restricted: + restriction = self._time_restriction + else: + restriction = TimeRestriction(earliest=None, latest=None, catchup=True) + try: + info = self.timetable.next_dagrun_info( + last_automated_data_interval=last_automated_dagrun, + restriction=restriction, + ) + except Exception: + log.exception( + "Failed to fetch run info after data interval %s for DAG %r", + last_automated_dagrun, + self.dag_id, + ) + info = None + return info + + def iter_dagrun_infos_between( + self, + earliest: datetime.datetime | None, + latest: datetime.datetime, + *, + align: bool = True, + ) -> Iterable[DagRunInfo]: + """ + Yield DagRunInfo using this DAG's timetable between given interval. + + DagRunInfo instances yielded if their ``logical_date`` is not earlier + than ``earliest``, nor later than ``latest``. The instances are ordered + by their ``logical_date`` from earliest to latest. + + If ``align`` is ``False``, the first run will happen immediately on + ``earliest``, even if it does not fall on the logical timetable schedule. + The default is ``True``. + + Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If + ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be + ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00`` + if ``align=True``. 
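# Illustrative sketch, not part of the diff: partial_subset() above returns a copy of
# the DAG trimmed to the matched tasks plus the requested relatives, and flags the copy
# as partial when tasks were dropped; next_dagrun_info() asks the timetable for the run
# after the last automated data interval (or from the beginning when given None), and
# iter_dagrun_infos_between() walks the schedule over a window. Hedged usage with
# made-up task ids and dates, assuming `scheduler_dag` is a deserialized SerializedDAG:
import pendulum

subset = scheduler_dag.partial_subset(
    task_ids={"extract"},          # hypothetical task id
    include_downstream=True,
    include_upstream=False,
)
# subset.partial is True whenever the copy contains fewer tasks than the original DAG.

info = scheduler_dag.next_dagrun_info(None)
if info is not None:
    print(info.logical_date, info.data_interval)

run_dates = [
    i.logical_date
    for i in scheduler_dag.iter_dagrun_infos_between(
        earliest=pendulum.datetime(2025, 1, 1, tz="UTC"),
        latest=pendulum.datetime(2025, 1, 7, tz="UTC"),
    )
]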
+ """ + if earliest is None: + earliest = self._time_restriction.earliest + if earliest is None: + raise ValueError("earliest was None and we had no value in time_restriction to fallback on") + earliest = coerce_datetime(earliest) + latest = coerce_datetime(latest) + + restriction = TimeRestriction(earliest, latest, catchup=True) + + try: + info = self.timetable.next_dagrun_info( + last_automated_data_interval=None, + restriction=restriction, + ) + except Exception: + log.exception( + "Failed to fetch run info after data interval %s for DAG %r", + None, + self.dag_id, + ) + info = None + + if info is None: + # No runs to be scheduled between the user-supplied timeframe. But + # if align=False, "invent" a data interval for the timeframe itself. + if not align: + yield DagRunInfo.interval(earliest, latest) + return + + # If align=False and earliest does not fall on the timetable's logical + # schedule, "invent" a data interval for it. + if not align and info.logical_date != earliest: + yield DagRunInfo.interval(earliest, info.data_interval.start) + + # Generate naturally according to schedule. + while info is not None: + yield info + try: + info = self.timetable.next_dagrun_info( + last_automated_data_interval=info.data_interval, + restriction=restriction, + ) + except Exception: + log.exception( + "Failed to fetch run info after data interval %s for DAG %r", + info.data_interval if info else "", + self.dag_id, + ) + break + + @provide_session + def get_concurrency_reached(self, session=NEW_SESSION) -> bool: + """Return a boolean indicating whether the max_active_tasks limit for this DAG has been reached.""" + from airflow.models.taskinstance import TaskInstance + + total_tasks = session.scalar( + select(func.count(TaskInstance.task_id)).where( + TaskInstance.dag_id == self.dag_id, + TaskInstance.state == TaskInstanceState.RUNNING, + ) + ) + return total_tasks >= self.max_active_tasks + + @provide_session + def create_dagrun( + self, + *, + run_id: str, + logical_date: datetime.datetime | None = None, + data_interval: tuple[datetime.datetime, datetime.datetime] | None = None, + run_after: datetime.datetime, + conf: dict | None = None, + run_type: DagRunType, + triggered_by: DagRunTriggeredByType, + triggering_user_name: str | None = None, + state: DagRunState, + start_date: datetime.datetime | None = None, + creating_job_id: int | None = None, + backfill_id: NonNegativeInt | None = None, + partition_key: str | None = None, + session: Session = NEW_SESSION, + ) -> DagRun: + """ + Create a run for this DAG to run its tasks. + + :param run_id: ID of the dag_run + :param logical_date: date of execution + :param run_after: the datetime before which dag won't run + :param conf: Dict containing configuration/parameters to pass to the DAG + :param triggered_by: the entity which triggers the dag_run + :param triggering_user_name: the user name who triggers the dag_run + :param start_date: the date this dag run should be evaluated + :param creating_job_id: ID of the job creating this DagRun + :param backfill_id: ID of the backfill run if one exists + :param session: Unused. Only added in compatibility with database isolation mode + :return: The created DAG run. + + :meta private: + """ + from airflow.models.dagrun import RUN_ID_REGEX + + logical_date = coerce_datetime(logical_date) + # For manual runs where logical_date is None, ensure no data_interval is set. 
+ if logical_date is None and data_interval is not None: + raise ValueError("data_interval must be None when logical_date is None") + + if data_interval and not isinstance(data_interval, DataInterval): + data_interval = DataInterval(*map(coerce_datetime, data_interval)) + + if isinstance(run_type, DagRunType): + pass + elif isinstance(run_type, str): # Ensure the input value is valid. + run_type = DagRunType(run_type) + else: + raise ValueError(f"run_type should be a DagRunType, not {type(run_type)}") + + if not isinstance(run_id, str): + raise ValueError(f"`run_id` should be a str, not {type(run_id)}") + + # This is also done on the DagRun model class, but SQLAlchemy column + # validator does not work well for some reason. + if not re.match(RUN_ID_REGEX, run_id): + regex = airflow_conf.get("scheduler", "allowed_run_id_pattern").strip() + if not regex or not re.match(regex, run_id): + raise ValueError( + f"The run_id provided '{run_id}' does not match regex pattern " + f"'{regex}' or '{RUN_ID_REGEX}'" + ) + + # Prevent a manual run from using an ID that looks like a scheduled run. + if run_type == DagRunType.MANUAL: + if (inferred_run_type := DagRunType.from_run_id(run_id)) != DagRunType.MANUAL: + raise ValueError( + f"A {run_type.value} DAG run cannot use ID {run_id!r} since it " + f"is reserved for {inferred_run_type.value} runs" + ) + + # todo: AIP-78 add verification that if run type is backfill then we have a backfill id + copied_params = self.params.deep_merge(conf) + copied_params.validate() + orm_dagrun = _create_orm_dagrun( + dag=self, + run_id=run_id, + logical_date=logical_date, + data_interval=data_interval, + run_after=coerce_datetime(run_after), + start_date=coerce_datetime(start_date), + conf=conf, + state=state, + run_type=run_type, + creating_job_id=creating_job_id, + backfill_id=backfill_id, + triggered_by=triggered_by, + triggering_user_name=triggering_user_name, + partition_key=partition_key, + session=session, + ) + + if self.deadline: + for deadline in cast("list", self.deadline): + if isinstance(deadline.reference, DeadlineReference.TYPES.DAGRUN): + deadline_time = deadline.reference.evaluate_with( + session=session, + interval=deadline.interval, + dag_id=self.dag_id, + run_id=run_id, + ) + if deadline_time is not None: + session.add( + Deadline( + deadline_time=deadline_time, + callback=deadline.callback, + dagrun_id=orm_dagrun.id, + dag_id=orm_dagrun.dag_id, + ) + ) + Stats.incr("deadline_alerts.deadline_created", tags={"dag_id": self.dag_id}) + + return orm_dagrun + + @provide_session + def set_task_instance_state( + self, + *, + task_id: str, + map_indexes: Collection[int] | None = None, + run_id: str | None = None, + state: TaskInstanceState, + upstream: bool = False, + downstream: bool = False, + future: bool = False, + past: bool = False, + commit: bool = True, + session=NEW_SESSION, + ) -> list[TaskInstance]: + """ + Set the state of a TaskInstance and clear downstream tasks in failed or upstream_failed state. + + :param task_id: Task ID of the TaskInstance + :param map_indexes: Only set TaskInstance if its map_index matches. + If None (default), all mapped TaskInstances of the task are set. 
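# Illustrative sketch, not part of the diff: create_dagrun() validates the run_id
# against the configured patterns and refuses a manual run whose id looks reserved for
# another run type. Hedged usage; the run_id, the triggered_by member, and `session`
# are assumptions for illustration:
import pendulum

from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

run = scheduler_dag.create_dagrun(
    run_id="manual__2025-01-01T00:00:00+00:00",
    run_after=pendulum.datetime(2025, 1, 1, tz="UTC"),
    run_type=DagRunType.MANUAL,
    triggered_by=DagRunTriggeredByType.TEST,    # assumed enum member
    state=DagRunState.QUEUED,
    session=session,
)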
+ :param run_id: The run_id of the TaskInstance + :param state: State to set the TaskInstance to + :param upstream: Include all upstream tasks of the given task_id + :param downstream: Include all downstream tasks of the given task_id + :param future: Include all future TaskInstances of the given task_id + :param commit: Commit changes + :param past: Include all past TaskInstances of the given task_id + """ + from airflow.api.common.mark_tasks import set_state + + task = self.get_task(task_id) + task.dag = self + + tasks_to_set_state: list[SerializedOperator | tuple[SerializedOperator, int]] + if map_indexes is None: + tasks_to_set_state = [task] + else: + tasks_to_set_state = [(task, map_index) for map_index in map_indexes] + + altered = set_state( + tasks=tasks_to_set_state, + run_id=run_id, + upstream=upstream, + downstream=downstream, + future=future, + past=past, + state=state, + commit=commit, + session=session, + ) + + if not commit: + return altered + + # Clear downstream tasks that are in failed/upstream_failed state to resume them. + # Flush the session so that the tasks marked success are reflected in the db. + session.flush() + subset = self.partial_subset( + task_ids={task_id}, + include_downstream=True, + include_upstream=False, + ) + + # Raises an error if not found + dr_id, logical_date = session.execute( + select(DagRun.id, DagRun.logical_date).where( + DagRun.run_id == run_id, DagRun.dag_id == self.dag_id + ) + ).one() + + # Now we want to clear downstreams of tasks that had their state set... + clear_kwargs = { + "only_failed": True, + "session": session, + # Exclude the task itself from being cleared. + "exclude_task_ids": frozenset((task_id,)), + } + if not future and not past: # Simple case 1: we're only dealing with exactly one run. + clear_kwargs["run_id"] = run_id + subset.clear(**clear_kwargs) + elif future and past: # Simple case 2: we're clearing ALL runs. + subset.clear(**clear_kwargs) + else: # Complex cases: we may have more than one run, based on a date range. + # Make 'future' and 'past' make some sense when multiple runs exist + # for the same logical date. We order runs by their id and only + # clear runs have larger/smaller ids. + exclude_run_id_stmt = select(DagRun.run_id).where(DagRun.logical_date == logical_date) + if future: + clear_kwargs["start_date"] = logical_date + exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id > dr_id) + else: + clear_kwargs["end_date"] = logical_date + exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id < dr_id) + subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), **clear_kwargs) + return altered + + @overload + def _get_task_instances( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None, + start_date: datetime.datetime | None, + end_date: datetime.datetime | None, + run_id: str | None, + state: TaskInstanceState | Sequence[TaskInstanceState], + exclude_task_ids: Collection[str | tuple[str, int]] | None, + exclude_run_ids: frozenset[str] | None, + session: Session, + ) -> Iterable[TaskInstance]: ... # pragma: no cover + + @overload + def _get_task_instances( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None, + as_pk_tuple: Literal[True], + start_date: datetime.datetime | None, + end_date: datetime.datetime | None, + run_id: str | None, + state: TaskInstanceState | Sequence[TaskInstanceState], + exclude_task_ids: Collection[str | tuple[str, int]] | None, + exclude_run_ids: frozenset[str] | None, + session: Session, + ) -> set[TaskInstanceKey]: ... 
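# Illustrative sketch, not part of the diff: set_task_instance_state() above marks one
# task (optionally specific map indexes) in a given run and, when commit=True, also
# clears failed/upstream_failed downstream tasks so they can re-run. Hedged usage with
# made-up identifiers:
from airflow.utils.state import TaskInstanceState

altered = scheduler_dag.set_task_instance_state(
    task_id="extract",                              # hypothetical task id
    run_id="manual__2025-01-01T00:00:00+00:00",     # hypothetical run id
    state=TaskInstanceState.SUCCESS,
    commit=True,
)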
# pragma: no cover + + def _get_task_instances( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None, + as_pk_tuple: Literal[True, None] = None, + start_date: datetime.datetime | None, + end_date: datetime.datetime | None, + run_id: str | None, + state: TaskInstanceState | Sequence[TaskInstanceState], + exclude_task_ids: Collection[str | tuple[str, int]] | None, + exclude_run_ids: frozenset[str] | None, + session: Session, + ) -> Iterable[TaskInstance] | set[TaskInstanceKey]: + from airflow.models.taskinstance import TaskInstance + + # If we are looking at dependent dags we want to avoid UNION calls + # in SQL (it doesn't play nice with fields that have no equality operator, + # like JSON types), we instead build our result set separately. + # + # This will be empty if we are only looking at one dag, in which case + # we can return the filtered TI query object directly. + result: set[TaskInstanceKey] = set() + + # Do we want full objects, or just the primary columns? + if as_pk_tuple: + tis_pk = select( + TaskInstance.dag_id, + TaskInstance.task_id, + TaskInstance.run_id, + TaskInstance.map_index, + ) + tis_pk = tis_pk.join(TaskInstance.dag_run) + else: + tis_full = select(TaskInstance) + tis_full = tis_full.join(TaskInstance.dag_run) + + # Apply common filters + def apply_filters(query): + if self.partial: + query = query.where( + TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids) + ) + else: + query = query.where(TaskInstance.dag_id == self.dag_id) + if run_id: + query = query.where(TaskInstance.run_id == run_id) + if start_date: + query = query.where(DagRun.logical_date >= start_date) + if task_ids is not None: + # Use the selector condition directly without intermediate variable + query = query.where(TaskInstance.ti_selector_condition(task_ids)) + if end_date: + query = query.where(DagRun.logical_date <= end_date) + return query + + if as_pk_tuple: + tis_pk = apply_filters(tis_pk) + else: + tis_full = apply_filters(tis_full) + + def apply_state_filter(query): + if state: + if isinstance(state, (str, TaskInstanceState)): + query = query.where(TaskInstance.state == state) + elif len(state) == 1: + query = query.where(TaskInstance.state == state[0]) + else: + # this is required to deal with NULL values + if None in state: + if all(x is None for x in state): + query = query.where(TaskInstance.state.is_(None)) + else: + not_none_state = [s for s in state if s] + query = query.where( + or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None)) + ) + else: + query = query.where(TaskInstance.state.in_(state)) + + if exclude_run_ids: + query = query.where(TaskInstance.run_id.not_in(exclude_run_ids)) + return query + + if as_pk_tuple: + tis_pk = apply_state_filter(tis_pk) + else: + tis_full = apply_state_filter(tis_full) + + if result or as_pk_tuple: + # Only execute the `ti` query if we have also collected some other results + if as_pk_tuple: + tis_query = session.execute(tis_pk).all() + result.update(TaskInstanceKey(**cols._mapping) for cols in tis_query) + else: + result.update(ti.key for ti in session.scalars(tis_full)) + + if exclude_task_ids is not None: + result = { + task + for task in result + if task.task_id not in exclude_task_ids + and (task.task_id, task.map_index) not in exclude_task_ids + } + + if as_pk_tuple: + return result + if result: + # We've been asked for objects, lets combine it all back in to a result set + ti_filters = TaskInstance.filter_for_tis(result) + if ti_filters is not None: + tis_final = 
select(TaskInstance).where(ti_filters) + return session.scalars(tis_final) + elif exclude_task_ids is None: + pass # Disable filter if not set. + elif isinstance(next(iter(exclude_task_ids), None), str): + tis_full = tis_full.where(TaskInstance.task_id.notin_(exclude_task_ids)) + else: + tis_full = tis_full.where( + tuple_(TaskInstance.task_id, TaskInstance.map_index).not_in(exclude_task_ids) + ) + + return session.scalars(tis_full) + + @overload + def clear( + self, + *, + dry_run: Literal[True], + task_ids: Collection[str | tuple[str, int]] | None = None, + run_id: str, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> list[TaskInstance]: ... # pragma: no cover + + @overload + def clear( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None = None, + run_id: str, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + dry_run: Literal[False] = False, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> int: ... # pragma: no cover + + @overload + def clear( + self, + *, + dry_run: Literal[True], + task_ids: Collection[str | tuple[str, int]] | None = None, + start_date: datetime.datetime | None = None, + end_date: datetime.datetime | None = None, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> list[TaskInstance]: ... # pragma: no cover + + @overload + def clear( + self, + *, + task_ids: Collection[str | tuple[str, int]] | None = None, + start_date: datetime.datetime | None = None, + end_date: datetime.datetime | None = None, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + dry_run: Literal[False] = False, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> int: ... # pragma: no cover + + @provide_session + def clear( + self, + task_ids: Collection[str | tuple[str, int]] | None = None, + *, + run_id: str | None = None, + start_date: datetime.datetime | None = None, + end_date: datetime.datetime | None = None, + only_failed: bool = False, + only_running: bool = False, + dag_run_state: DagRunState = DagRunState.QUEUED, + dry_run: bool = False, + session: Session = NEW_SESSION, + exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), + exclude_run_ids: frozenset[str] | None = frozenset(), + run_on_latest_version: bool = False, + ) -> int | Iterable[TaskInstance]: + """ + Clear a set of task instances associated with the current dag for a specified date range. 
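# Illustrative sketch, not part of the diff: clear() either previews the matching task
# instances (dry_run=True) or resets them and returns how many were cleared. Hedged
# usage for a single run; the run id is an assumption:
to_clear = scheduler_dag.clear(
    run_id="manual__2025-01-01T00:00:00+00:00",
    only_failed=True,
    dry_run=True,
)
print(f"would clear {len(to_clear)} task instances")

cleared_count = scheduler_dag.clear(
    run_id="manual__2025-01-01T00:00:00+00:00",
    only_failed=True,
)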
+ + :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear + :param run_id: The run_id for which the tasks should be cleared + :param start_date: The minimum logical_date to clear + :param end_date: The maximum logical_date to clear + :param only_failed: Only clear failed tasks + :param only_running: Only clear running tasks. + :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not + be changed. + :param dry_run: Find the tasks to clear but don't clear them. + :param run_on_latest_version: whether to run on latest serialized DAG and Bundle version + :param session: The sqlalchemy session to use + :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) + tuples that should not be cleared + :param exclude_run_ids: A set of ``run_id`` or (``run_id``) + """ + from airflow.models.taskinstance import clear_task_instances + + state: list[TaskInstanceState] = [] + if only_failed: + state += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED] + if only_running: + # Yes, having `+=` doesn't make sense, but this was the existing behaviour + state += [TaskInstanceState.RUNNING] + + tis_result = self._get_task_instances( + task_ids=task_ids, + start_date=start_date, + end_date=end_date, + run_id=run_id, + state=state, + session=session, + exclude_task_ids=exclude_task_ids, + exclude_run_ids=exclude_run_ids, + ) + + if dry_run: + return list(tis_result) + + tis = list(tis_result) + + count = len(tis) + if count == 0: + return 0 + + clear_task_instances( + list(tis), + session, + dag_run_state=dag_run_state, + run_on_latest_version=run_on_latest_version, + ) + + session.flush() + return count + + @classmethod + def clear_dags( + cls, + dags: Iterable[SerializedDAG], + *, + start_date=None, + end_date=None, + only_failed=False, + only_running=False, + dag_run_state=DagRunState.QUEUED, + dry_run: bool = False, + ): + if dry_run: + tis = itertools.chain.from_iterable( + dag.clear( + start_date=start_date, + end_date=end_date, + only_failed=only_failed, + only_running=only_running, + dag_run_state=dag_run_state, + dry_run=True, + ) + for dag in dags + ) + return list(tis) + + return sum( + dag.clear( + start_date=start_date, + end_date=end_date, + only_failed=only_failed, + only_running=only_running, + dag_run_state=dag_run_state, + dry_run=False, + ) + for dag in dags + ) + + def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: + """Return edge information for the given pair of tasks or an empty edge if there is no information.""" + # Note - older serialized dags may not have edge_info being a dict at all + empty = cast("EdgeInfoType", {}) + if self.edge_info: + return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) + return empty + + +@provide_session +def _create_orm_dagrun( + *, + dag: SerializedDAG, + run_id: str, + logical_date: datetime.datetime | None, + data_interval: DataInterval | None, + run_after: datetime.datetime, + start_date: datetime.datetime | None, + conf: Any, + state: DagRunState | None, + run_type: DagRunType, + creating_job_id: int | None, + backfill_id: NonNegativeInt | None, + triggered_by: DagRunTriggeredByType, + triggering_user_name: str | None = None, + partition_key: str | None = None, + session: Session = NEW_SESSION, +) -> DagRun: + bundle_version = None + if not dag.disable_bundle_versioning: + bundle_version = session.scalar( + select(DagModel.bundle_version).where(DagModel.dag_id == dag.dag_id), + ) + dag_version = 
DagVersion.get_latest_version(dag.dag_id, session=session) + if not dag_version: + raise AirflowException(f"Cannot create DagRun for DAG {dag.dag_id} because the dag is not serialized") + + run = DagRun( + dag_id=dag.dag_id, + run_id=run_id, + logical_date=logical_date, + start_date=start_date, + run_after=run_after, + conf=conf, + state=state, + run_type=run_type, + creating_job_id=creating_job_id, + data_interval=data_interval, + triggered_by=triggered_by, + triggering_user_name=triggering_user_name, + backfill_id=backfill_id, + bundle_version=bundle_version, + partition_key=partition_key, + ) + # Load defaults into the following two fields to ensure result can be serialized detached + max_log_template_id = session.scalar(select(func.max(LogTemplate.__table__.c.id))) + run.log_template_id = int(max_log_template_id) if max_log_template_id is not None else 0 + run.created_dag_version = dag_version + run.consumed_asset_events = [] + session.add(run) + session.flush() + run.dag = dag + # create the associated task instances + # state is None at the moment of creation + run.verify_integrity(session=session, dag_version_id=dag_version.id) + return run diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py index 6c0add8cdfb92..3dcb62aa30f4d 100644 --- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py +++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py @@ -37,7 +37,7 @@ from airflow.serialization.serialized_objects import SerializedDAG, SerializedOperator -@attrs.define(kw_only=True, repr=False) +@attrs.define(eq=False, hash=False, kw_only=True) class SerializedTaskGroup(DAGNode): """Serialized representation of a TaskGroup used in protected processes.""" diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 7ada1aea7c03e..ce136f998cee8 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -21,31 +21,18 @@ from __future__ import annotations import collections.abc -import copy import datetime import enum import itertools import logging import math -import re import sys import weakref from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence from functools import cache, cached_property, lru_cache from inspect import signature from textwrap import dedent -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Literal, - NamedTuple, - TypeAlias, - TypeGuard, - TypeVar, - cast, - overload, -) +from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, TypeAlias, TypeVar, cast, overload import attrs import lazy_object_proxy @@ -53,30 +40,17 @@ import pydantic from dateutil import relativedelta from pendulum.tz.timezone import FixedTimezone, Timezone -from sqlalchemy import func, or_, select, tuple_ from airflow import macros from airflow._shared.module_loading import import_string, qualname -from airflow._shared.timezones.timezone import coerce_datetime, from_timestamp, parse_timezone, utcnow +from airflow._shared.timezones.timezone import from_timestamp, parse_timezone, utcnow from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest -from airflow.configuration import conf as airflow_conf -from airflow.exceptions import ( - AirflowException, - DeserializationError, - SerializationError, - TaskNotFound, -) +from airflow.exceptions import 
AirflowException, DeserializationError, SerializationError from airflow.models.connection import Connection -from airflow.models.dag import DagModel -from airflow.models.dag_version import DagVersion -from airflow.models.dagrun import RUN_ID_REGEX, DagRun -from airflow.models.deadline import Deadline from airflow.models.expandinput import create_expand_input from airflow.models.taskinstancekey import TaskInstanceKey -from airflow.models.tasklog import LogTemplate from airflow.models.xcom import XComModel from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg -from airflow.observability.stats import Stats from airflow.sdk import DAG, Asset, AssetAlias, BaseOperator, XComArg from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this into the scheduler? from airflow.sdk.definitions._internal.node import DAGNode @@ -86,7 +60,7 @@ AssetUniqueKey, BaseAsset, ) -from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference +from airflow.sdk.definitions.deadline import DeadlineAlert from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.operator_resources import Resources from airflow.sdk.definitions.param import Param, ParamsDict @@ -101,6 +75,7 @@ SerializedAssetBase, SerializedAssetUniqueKey, ) +from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.definitions.param import SerializedParam, SerializedParamsDict from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup from airflow.serialization.encoders import ( @@ -126,39 +101,29 @@ from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep -from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.code_utils import get_python_source from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor from airflow.utils.db import LazySelectSequence from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.state import DagRunState, TaskInstanceState -from airflow.utils.types import DagRunTriggeredByType, DagRunType +from airflow.utils.session import create_session if TYPE_CHECKING: from inspect import Parameter - from pydantic import NonNegativeInt - from sqlalchemy.orm import Session + from kubernetes.client import models as k8s # noqa: TC004 from airflow.models.expandinput import SchedulerExpandInput from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator from airflow.models.taskinstance import TaskInstance + from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TC004 from airflow.sdk import BaseOperatorLink - from airflow.sdk.definitions.edges import EdgeInfoType from airflow.serialization.json_schema import Validator from airflow.task.trigger_rule import TriggerRule from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.timetables.base import DagRunInfo, Timetable from airflow.timetables.simple import PartitionMapper - try: - from kubernetes.client import models as k8s # noqa: TC004 - - from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TC004 - except ImportError: - pass - SerializedOperator: TypeAlias 
= "SerializedMappedOperator | SerializedBaseOperator" SdkOperator: TypeAlias = BaseOperator | MappedOperator @@ -480,7 +445,7 @@ def from_json(cls, serialized_obj: str) -> BaseSerialization | dict | list | set return cls.from_dict(json.loads(serialized_obj)) @classmethod - def from_dict(cls, serialized_obj: dict[Encoding, Any]) -> BaseSerialization | dict | list | set | tuple: + def from_dict(cls, serialized_obj: dict[Encoding, Any]) -> Any: """Deserialize a dict of type decorators and reconstructs all DAGs and operators it contains.""" return cls.deserialize(serialized_obj) @@ -617,7 +582,7 @@ def serialize( type_=DAT.ASSET_ALIAS_UNIQUE_KEY, ) elif isinstance(var, DAG): - return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG) + return cls._encode(DagSerialization.serialize_dag(var), type_=DAT.DAG) elif isinstance(var, DeadlineAlert): return cls._encode(DeadlineAlert.serialize_deadline_alert(var), type_=DAT.DEADLINE_ALERT) elif isinstance(var, Resources): @@ -761,7 +726,7 @@ def deserialize(cls, encoded_var: Any) -> Any: elif type_ == DAT.ASSET_ALIAS_UNIQUE_KEY: return AssetAliasUniqueKey(name=var["name"]) elif type_ == DAT.DAG: - return SerializedDAG.deserialize_dag(var) + return DagSerialization.deserialize_dag(var) elif type_ == DAT.OP: return SerializedBaseOperator.deserialize_operator(var) elif type_ == DAT.DATETIME: @@ -2190,141 +2155,11 @@ def get_parse_time_mapped_ti_count(self) -> int: return group.get_parse_time_mapped_ti_count() -@provide_session -def _create_orm_dagrun( - *, - dag: SerializedDAG, - run_id: str, - logical_date: datetime.datetime | None, - data_interval: DataInterval | None, - run_after: datetime.datetime, - start_date: datetime.datetime | None, - conf: Any, - state: DagRunState | None, - run_type: DagRunType, - creating_job_id: int | None, - backfill_id: NonNegativeInt | None, - triggered_by: DagRunTriggeredByType, - triggering_user_name: str | None = None, - partition_key: str | None = None, - session: Session = NEW_SESSION, -) -> DagRun: - bundle_version = None - if not dag.disable_bundle_versioning: - bundle_version = session.scalar( - select(DagModel.bundle_version).where(DagModel.dag_id == dag.dag_id), - ) - dag_version = DagVersion.get_latest_version(dag.dag_id, session=session) - if not dag_version: - raise AirflowException(f"Cannot create DagRun for DAG {dag.dag_id} because the dag is not serialized") - - run = DagRun( - dag_id=dag.dag_id, - run_id=run_id, - logical_date=logical_date, - start_date=start_date, - run_after=run_after, - conf=conf, - state=state, - run_type=run_type, - creating_job_id=creating_job_id, - data_interval=data_interval, - triggered_by=triggered_by, - triggering_user_name=triggering_user_name, - backfill_id=backfill_id, - bundle_version=bundle_version, - partition_key=partition_key, - ) - # Load defaults into the following two fields to ensure result can be serialized detached - max_log_template_id = session.scalar(select(func.max(LogTemplate.__table__.c.id))) - run.log_template_id = int(max_log_template_id) if max_log_template_id is not None else 0 - run.created_dag_version = dag_version - run.consumed_asset_events = [] - session.add(run) - session.flush() - run.dag = dag - # create the associated task instances - # state is None at the moment of creation - run.verify_integrity(session=session, dag_version_id=dag_version.id) - return run - - -class SerializedDAG(BaseSerialization): - """ - A JSON serializable representation of DAG. 
- - A stringified DAG can only be used in the scope of scheduler and webserver, because fields - that are not serializable, such as functions and customer defined classes, are casted to - strings. - """ +class DagSerialization(BaseSerialization): + """Logic to encode a ``DAG`` object and decode the data into ``SerializedDAG``.""" _decorated_fields: ClassVar[set[str]] = {"default_args", "access_control"} - access_control: dict[str, dict[str, Collection[str]]] | None = None - catchup: bool - dag_id: str - dag_display_name: str - dagrun_timeout: datetime.timedelta | None - deadline: list[DeadlineAlert] | DeadlineAlert | None - default_args: dict[str, Any] - description: str | None - disable_bundle_versioning: bool - doc_md: str | None - edge_info: dict[str, dict[str, EdgeInfoType]] - end_date: datetime.datetime | None - fail_fast: bool - has_on_failure_callback: bool - has_on_success_callback: bool - is_paused_upon_creation: bool | None - max_active_runs: int - max_active_tasks: int - max_consecutive_failed_dag_runs: int - owner_links: dict[str, str] - params: SerializedParamsDict - partial: bool - render_template_as_native_obj: bool - start_date: datetime.datetime | None - tags: set[str] - task_dict: dict[str, SerializedOperator] - task_group: SerializedTaskGroup - template_searchpath: tuple[str, ...] | None - timetable: Timetable - timezone: FixedTimezone | Timezone - - last_loaded: datetime.datetime - # this will only be set at serialization time - # it's only use is for determining the relative fileloc based only on the serialize dag - _processor_dags_folder: str - - def __init__(self, *, dag_id: str) -> None: - self.catchup = False # Schema default - self.dag_id = self.dag_display_name = dag_id - self.dagrun_timeout = None - self.deadline = None - self.default_args = {} - self.description = None - self.disable_bundle_versioning = False - self.doc_md = None - self.edge_info = {} - self.end_date = None - self.fail_fast = False - self.has_on_failure_callback = False - self.has_on_success_callback = False - self.is_paused_upon_creation = None - self.max_active_runs = 16 # Schema default - self.max_active_tasks = 16 # Schema default - self.max_consecutive_failed_dag_runs = 0 # Schema default - self.owner_links = {} - self.params = SerializedParamsDict() - self.partial = False - self.render_template_as_native_obj = False - self.start_date = None - self.tags = set() - self.template_searchpath = None - - def __repr__(self) -> str: - return f"" - @staticmethod def __get_constructor_defaults(): param_to_attr = { @@ -2341,39 +2176,6 @@ def __get_constructor_defaults(): _json_schema: ClassVar[Validator] = lazy_object_proxy.Proxy(load_dag_schema) - @classmethod - def get_serialized_fields(cls) -> frozenset[str]: - return frozenset( - { - "access_control", - "catchup", - "dag_display_name", - "dag_id", - "dagrun_timeout", - "deadline", - "default_args", - "description", - "disable_bundle_versioning", - "doc_md", - "edge_info", - "end_date", - "fail_fast", - "fileloc", - "is_paused_upon_creation", - "max_active_runs", - "max_active_tasks", - "max_consecutive_failed_dag_runs", - "owner_links", - "relative_fileloc", - "render_template_as_native_obj", - "start_date", - "tags", - "task_group", - "timetable", - "timezone", - } - ) - @classmethod def serialize_dag(cls, dag: DAG) -> dict: """Serialize a DAG into a JSON object.""" @@ -2732,7 +2534,7 @@ def _create_compat_timetable(value): var_["group"] = "asset" for k, v in list(task_var.items()): - op_defaults = SerializedDAG.get_schema_defaults("operator") + 
op_defaults = DagSerialization.get_schema_defaults("operator") if k in op_defaults and v == op_defaults[k]: del task_var[k] @@ -2762,899 +2564,6 @@ def from_dict(cls, serialized_obj: dict) -> SerializedDAG: # Pass client_defaults directly to deserialize_dag return cls.deserialize_dag(serialized_obj["dag"], client_defaults) - @classmethod - @provide_session - def bulk_write_to_db( - cls, - bundle_name: str, - bundle_version: str | None, - dags: Collection[DAG | LazyDeserializedDAG], - parse_duration: float | None = None, - session: Session = NEW_SESSION, - ) -> None: - """ - Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB. - - :param dags: the DAG objects to save to the DB - :return: None - """ - if not dags: - return - - from airflow.dag_processing.collection import AssetModelOperation, DagModelOperation - - log.info("Sync %s DAGs", len(dags)) - dag_op = DagModelOperation( - bundle_name=bundle_name, - bundle_version=bundle_version, - dags={d.dag_id: LazyDeserializedDAG.from_dag(d) for d in dags}, - ) - - orm_dags = dag_op.add_dags(session=session) - dag_op.update_dags(orm_dags, parse_duration, session=session) - - asset_op = AssetModelOperation.collect(dag_op.dags) - - orm_assets = asset_op.sync_assets(session=session) - orm_asset_aliases = asset_op.sync_asset_aliases(session=session) - session.flush() # This populates id so we can create fks in later calls. - - orm_dags = dag_op.find_orm_dags(session=session) # Refetch so relationship is up to date. - asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) - asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) - asset_op.add_dag_asset_name_uri_references(session=session) - asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) - asset_op.activate_assets_if_possible(orm_assets.values(), session=session) - session.flush() # Activation is needed when we add trigger references. 
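# A minimal sketch of how the split introduced in this refactor is meant to compose,
# using only names that appear in this diff: DagSerialization carries the encode/decode
# logic, while the scheduler-side SerializedDAG (now under
# airflow.serialization.definitions.dag) keeps the DB-facing behaviour such as
# bulk_write_to_db. The dag id and bundle name below are placeholders, and an
# initialized metadata database is assumed.
from airflow.sdk import DAG
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization, LazyDeserializedDAG

sdk_dag = DAG("example_sync_dag", schedule=None)   # definition-time (task SDK) DAG
data = DagSerialization.to_dict(sdk_dag)           # JSON-safe dict matching the DAG schema
lazy_dag = LazyDeserializedDAG(data=data)          # defers full deserialization until needed
SerializedDAG.bulk_write_to_db(                    # sync DagModel rows and asset references
    bundle_name="example-bundle",
    bundle_version=None,
    dags=[lazy_dag],
)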
- - asset_op.add_asset_trigger_references(orm_assets, session=session) - dag_op.update_dag_asset_expression(orm_dags=orm_dags, orm_assets=orm_assets) - session.flush() - - @property - def tasks(self) -> Sequence[SerializedOperator]: - return list(self.task_dict.values()) - - @property - def task_ids(self) -> list[str]: - return list(self.task_dict) - - @property - def roots(self) -> list[SerializedOperator]: - return [task for task in self.tasks if not task.upstream_list] - - @property - def owner(self) -> str: - return ", ".join({t.owner for t in self.tasks}) - - def has_task(self, task_id: str) -> bool: - return task_id in self.task_dict - - def get_task(self, task_id: str) -> SerializedOperator: - if task_id in self.task_dict: - return self.task_dict[task_id] - raise TaskNotFound(f"Task {task_id} not found") - - @property - def task_group_dict(self): - return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None} - - def partial_subset( - self, - task_ids: str | Iterable[str], - include_downstream: bool = False, - include_upstream: bool = True, - include_direct_upstream: bool = False, - exclude_original: bool = False, - ): - from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator - - def is_task(obj) -> TypeGuard[SerializedOperator]: - return isinstance(obj, (SerializedMappedOperator, SerializedBaseOperator)) - - # deep-copying self.task_dict and self.task_group takes a long time, and we don't want all - # the tasks anyway, so we copy the tasks manually later - memo = {id(self.task_dict): None, id(self.task_group): None} - dag = copy.deepcopy(self, memo) - - if isinstance(task_ids, str): - matched_tasks = [t for t in self.tasks if task_ids in t.task_id] - else: - matched_tasks = [t for t in self.tasks if t.task_id in task_ids] - - also_include_ids: set[str] = set() - for t in matched_tasks: - if include_downstream: - for rel in t.get_flat_relatives(upstream=False): - also_include_ids.add(rel.task_id) - if rel not in matched_tasks: # if it's in there, we're already processing it - # need to include setups and teardowns for tasks that are in multiple - # non-collinear setup/teardown paths - if not rel.is_setup and not rel.is_teardown: - also_include_ids.update( - x.task_id for x in rel.get_upstreams_only_setups_and_teardowns() - ) - if include_upstream: - also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups()) - else: - if not t.is_setup and not t.is_teardown: - also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns()) - if t.is_setup and not include_downstream: - also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown) - - also_include: list[SerializedOperator] = [self.task_dict[x] for x in also_include_ids] - direct_upstreams: list[SerializedOperator] = [] - if include_direct_upstream: - for t in itertools.chain(matched_tasks, also_include): - upstream = (u for u in t.upstream_list if is_task(u)) - direct_upstreams.extend(upstream) - - # Make sure to not recursively deepcopy the dag or task_group while copying the task. 
- # task_group is reset later - def _deepcopy_task(t) -> SerializedOperator: - memo.setdefault(id(t.task_group), None) - return copy.deepcopy(t, memo) - - # Compiling the unique list of tasks that made the cut - if exclude_original: - matched_tasks = [] - dag.task_dict = { - t.task_id: _deepcopy_task(t) - for t in itertools.chain(matched_tasks, also_include, direct_upstreams) - } - - def filter_task_group(group, parent_group): - """Exclude tasks not included in the partial dag from the given TaskGroup.""" - # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy - # and then manually deep copy the instances. (memo argument to deepcopy only works for instances - # of classes, not "native" properties of an instance) - copied = copy.copy(group) - - memo[id(group.children)] = {} - if parent_group: - memo[id(group.parent_group)] = parent_group - for attr in type(group).__slots__: - value = getattr(group, attr) - value = copy.deepcopy(value, memo) - object.__setattr__(copied, attr, value) - - proxy = weakref.proxy(copied) - - for child in group.children.values(): - if is_task(child): - if child.task_id in dag.task_dict: - task = copied.children[child.task_id] = dag.task_dict[child.task_id] - task.task_group = proxy - else: - filtered_child = filter_task_group(child, proxy) - - # Only include this child TaskGroup if it is non-empty. - if filtered_child.children: - copied.children[child.group_id] = filtered_child - - return copied - - object.__setattr__(dag, "task_group", filter_task_group(self.task_group, None)) - - # Removing upstream/downstream references to tasks and TaskGroups that did not make - # the cut. - groups = dag.task_group.get_task_group_dict() - for g in groups.values(): - g.upstream_group_ids.intersection_update(groups) - g.downstream_group_ids.intersection_update(groups) - g.upstream_task_ids.intersection_update(dag.task_dict) - g.downstream_task_ids.intersection_update(dag.task_dict) - - for t in dag.tasks: - # Removing upstream/downstream references to tasks that did not - # make the cut - t.upstream_task_ids.intersection_update(dag.task_dict) - t.downstream_task_ids.intersection_update(dag.task_dict) - - dag.partial = len(dag.tasks) < len(self.tasks) - - return dag - - @cached_property - def _time_restriction(self) -> TimeRestriction: - start_dates = [t.start_date for t in self.tasks if t.start_date] - if self.start_date is not None: - start_dates.append(self.start_date) - earliest = None - if start_dates: - earliest = coerce_datetime(min(start_dates)) - latest = coerce_datetime(self.end_date) - end_dates = [t.end_date for t in self.tasks if t.end_date] - if len(end_dates) == len(self.tasks): # not exists null end_date - if self.end_date is not None: - end_dates.append(self.end_date) - if end_dates: - latest = coerce_datetime(max(end_dates)) - return TimeRestriction(earliest, latest, self.catchup) - - def next_dagrun_info( - self, - last_automated_dagrun: None | DataInterval, - *, - restricted: bool = True, - ) -> DagRunInfo | None: - """ - Get information about the next DagRun of this dag after ``date_last_automated_dagrun``. - - This calculates what time interval the next DagRun should operate on - (its logical date) and when it can be scheduled, according to the - dag's timetable, start_date, end_date, etc. This doesn't check max - active run or any other "max_active_tasks" type limits, but only - performs calculations based on the various date and interval fields of - this dag and its tasks. 
- - :param last_automated_dagrun: The ``max(logical_date)`` of - existing "automated" DagRuns for this dag (scheduled or backfill, - but not manual). - :param restricted: If set to *False* (default is *True*), ignore - ``start_date``, ``end_date``, and ``catchup`` specified on the DAG - or tasks. - :return: DagRunInfo of the next dagrun, or None if a dagrun is not - going to be scheduled. - """ - if restricted: - restriction = self._time_restriction - else: - restriction = TimeRestriction(earliest=None, latest=None, catchup=True) - try: - info = self.timetable.next_dagrun_info( - last_automated_data_interval=last_automated_dagrun, - restriction=restriction, - ) - except Exception: - log.exception( - "Failed to fetch run info after data interval %s for DAG %r", - last_automated_dagrun, - self.dag_id, - ) - info = None - return info - - def iter_dagrun_infos_between( - self, - earliest: datetime.datetime | None, - latest: datetime.datetime, - *, - align: bool = True, - ) -> Iterable[DagRunInfo]: - """ - Yield DagRunInfo using this DAG's timetable between given interval. - - DagRunInfo instances yielded if their ``logical_date`` is not earlier - than ``earliest``, nor later than ``latest``. The instances are ordered - by their ``logical_date`` from earliest to latest. - - If ``align`` is ``False``, the first run will happen immediately on - ``earliest``, even if it does not fall on the logical timetable schedule. - The default is ``True``. - - Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If - ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be - ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00`` - if ``align=True``. - """ - if earliest is None: - earliest = self._time_restriction.earliest - if earliest is None: - raise ValueError("earliest was None and we had no value in time_restriction to fallback on") - earliest = coerce_datetime(earliest) - latest = coerce_datetime(latest) - - restriction = TimeRestriction(earliest, latest, catchup=True) - - try: - info = self.timetable.next_dagrun_info( - last_automated_data_interval=None, - restriction=restriction, - ) - except Exception: - log.exception( - "Failed to fetch run info after data interval %s for DAG %r", - None, - self.dag_id, - ) - info = None - - if info is None: - # No runs to be scheduled between the user-supplied timeframe. But - # if align=False, "invent" a data interval for the timeframe itself. - if not align: - yield DagRunInfo.interval(earliest, latest) - return - - # If align=False and earliest does not fall on the timetable's logical - # schedule, "invent" a data interval for it. - if not align and info.logical_date != earliest: - yield DagRunInfo.interval(earliest, info.data_interval.start) - - # Generate naturally according to schedule. 
- while info is not None: - yield info - try: - info = self.timetable.next_dagrun_info( - last_automated_data_interval=info.data_interval, - restriction=restriction, - ) - except Exception: - log.exception( - "Failed to fetch run info after data interval %s for DAG %r", - info.data_interval if info else "", - self.dag_id, - ) - break - - @provide_session - def get_concurrency_reached(self, session=NEW_SESSION) -> bool: - """Return a boolean indicating whether the max_active_tasks limit for this DAG has been reached.""" - from airflow.models.taskinstance import TaskInstance - - total_tasks = session.scalar( - select(func.count(TaskInstance.task_id)).where( - TaskInstance.dag_id == self.dag_id, - TaskInstance.state == TaskInstanceState.RUNNING, - ) - ) - return total_tasks >= self.max_active_tasks - - @provide_session - def create_dagrun( - self, - *, - run_id: str, - logical_date: datetime.datetime | None = None, - data_interval: tuple[datetime.datetime, datetime.datetime] | None = None, - run_after: datetime.datetime, - conf: dict | None = None, - run_type: DagRunType, - triggered_by: DagRunTriggeredByType, - triggering_user_name: str | None = None, - state: DagRunState, - start_date: datetime.datetime | None = None, - creating_job_id: int | None = None, - backfill_id: NonNegativeInt | None = None, - partition_key: str | None = None, - session: Session = NEW_SESSION, - ) -> DagRun: - """ - Create a run for this DAG to run its tasks. - - :param run_id: ID of the dag_run - :param logical_date: date of execution - :param run_after: the datetime before which dag won't run - :param conf: Dict containing configuration/parameters to pass to the DAG - :param triggered_by: the entity which triggers the dag_run - :param triggering_user_name: the user name who triggers the dag_run - :param start_date: the date this dag run should be evaluated - :param creating_job_id: ID of the job creating this DagRun - :param backfill_id: ID of the backfill run if one exists - :param session: Unused. Only added in compatibility with database isolation mode - :return: The created DAG run. - - :meta private: - """ - logical_date = coerce_datetime(logical_date) - # For manual runs where logical_date is None, ensure no data_interval is set. - if logical_date is None and data_interval is not None: - raise ValueError("data_interval must be None when logical_date is None") - - if data_interval and not isinstance(data_interval, DataInterval): - data_interval = DataInterval(*map(coerce_datetime, data_interval)) - - if isinstance(run_type, DagRunType): - pass - elif isinstance(run_type, str): # Ensure the input value is valid. - run_type = DagRunType(run_type) - else: - raise ValueError(f"run_type should be a DagRunType, not {type(run_type)}") - - if not isinstance(run_id, str): - raise ValueError(f"`run_id` should be a str, not {type(run_id)}") - - # This is also done on the DagRun model class, but SQLAlchemy column - # validator does not work well for some reason. - if not re.match(RUN_ID_REGEX, run_id): - regex = airflow_conf.get("scheduler", "allowed_run_id_pattern").strip() - if not regex or not re.match(regex, run_id): - raise ValueError( - f"The run_id provided '{run_id}' does not match regex pattern " - f"'{regex}' or '{RUN_ID_REGEX}'" - ) - - # Prevent a manual run from using an ID that looks like a scheduled run. 
- if run_type == DagRunType.MANUAL: - if (inferred_run_type := DagRunType.from_run_id(run_id)) != DagRunType.MANUAL: - raise ValueError( - f"A {run_type.value} DAG run cannot use ID {run_id!r} since it " - f"is reserved for {inferred_run_type.value} runs" - ) - - # todo: AIP-78 add verification that if run type is backfill then we have a backfill id - copied_params = self.params.deep_merge(conf) - copied_params.validate() - orm_dagrun = _create_orm_dagrun( - dag=self, - run_id=run_id, - logical_date=logical_date, - data_interval=data_interval, - run_after=coerce_datetime(run_after), - start_date=coerce_datetime(start_date), - conf=conf, - state=state, - run_type=run_type, - creating_job_id=creating_job_id, - backfill_id=backfill_id, - triggered_by=triggered_by, - triggering_user_name=triggering_user_name, - partition_key=partition_key, - session=session, - ) - - if self.deadline: - for deadline in cast("list", self.deadline): - if isinstance(deadline.reference, DeadlineReference.TYPES.DAGRUN): - deadline_time = deadline.reference.evaluate_with( - session=session, - interval=deadline.interval, - dag_id=self.dag_id, - run_id=run_id, - ) - if deadline_time is not None: - session.add( - Deadline( - deadline_time=deadline_time, - callback=deadline.callback, - dagrun_id=orm_dagrun.id, - dag_id=orm_dagrun.dag_id, - ) - ) - Stats.incr("deadline_alerts.deadline_created", tags={"dag_id": self.dag_id}) - - return orm_dagrun - - @provide_session - def set_task_instance_state( - self, - *, - task_id: str, - map_indexes: Collection[int] | None = None, - run_id: str | None = None, - state: TaskInstanceState, - upstream: bool = False, - downstream: bool = False, - future: bool = False, - past: bool = False, - commit: bool = True, - session=NEW_SESSION, - ) -> list[TaskInstance]: - """ - Set the state of a TaskInstance and clear downstream tasks in failed or upstream_failed state. - - :param task_id: Task ID of the TaskInstance - :param map_indexes: Only set TaskInstance if its map_index matches. - If None (default), all mapped TaskInstances of the task are set. - :param run_id: The run_id of the TaskInstance - :param state: State to set the TaskInstance to - :param upstream: Include all upstream tasks of the given task_id - :param downstream: Include all downstream tasks of the given task_id - :param future: Include all future TaskInstances of the given task_id - :param commit: Commit changes - :param past: Include all past TaskInstances of the given task_id - """ - from airflow.api.common.mark_tasks import set_state - - task = self.get_task(task_id) - task.dag = self - - tasks_to_set_state: list[SerializedOperator | tuple[SerializedOperator, int]] - if map_indexes is None: - tasks_to_set_state = [task] - else: - tasks_to_set_state = [(task, map_index) for map_index in map_indexes] - - altered = set_state( - tasks=tasks_to_set_state, - run_id=run_id, - upstream=upstream, - downstream=downstream, - future=future, - past=past, - state=state, - commit=commit, - session=session, - ) - - if not commit: - return altered - - # Clear downstream tasks that are in failed/upstream_failed state to resume them. - # Flush the session so that the tasks marked success are reflected in the db. 
- session.flush() - subset = self.partial_subset( - task_ids={task_id}, - include_downstream=True, - include_upstream=False, - ) - - # Raises an error if not found - dr_id, logical_date = session.execute( - select(DagRun.id, DagRun.logical_date).where( - DagRun.run_id == run_id, DagRun.dag_id == self.dag_id - ) - ).one() - - # Now we want to clear downstreams of tasks that had their state set... - clear_kwargs = { - "only_failed": True, - "session": session, - # Exclude the task itself from being cleared. - "exclude_task_ids": frozenset((task_id,)), - } - if not future and not past: # Simple case 1: we're only dealing with exactly one run. - clear_kwargs["run_id"] = run_id - subset.clear(**clear_kwargs) - elif future and past: # Simple case 2: we're clearing ALL runs. - subset.clear(**clear_kwargs) - else: # Complex cases: we may have more than one run, based on a date range. - # Make 'future' and 'past' make some sense when multiple runs exist - # for the same logical date. We order runs by their id and only - # clear runs have larger/smaller ids. - exclude_run_id_stmt = select(DagRun.run_id).where(DagRun.logical_date == logical_date) - if future: - clear_kwargs["start_date"] = logical_date - exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id > dr_id) - else: - clear_kwargs["end_date"] = logical_date - exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id < dr_id) - subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), **clear_kwargs) - return altered - - @overload - def _get_task_instances( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None, - start_date: datetime.datetime | None, - end_date: datetime.datetime | None, - run_id: str | None, - state: TaskInstanceState | Sequence[TaskInstanceState], - exclude_task_ids: Collection[str | tuple[str, int]] | None, - exclude_run_ids: frozenset[str] | None, - session: Session, - ) -> Iterable[TaskInstance]: ... # pragma: no cover - - @overload - def _get_task_instances( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None, - as_pk_tuple: Literal[True], - start_date: datetime.datetime | None, - end_date: datetime.datetime | None, - run_id: str | None, - state: TaskInstanceState | Sequence[TaskInstanceState], - exclude_task_ids: Collection[str | tuple[str, int]] | None, - exclude_run_ids: frozenset[str] | None, - session: Session, - ) -> set[TaskInstanceKey]: ... # pragma: no cover - - def _get_task_instances( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None, - as_pk_tuple: Literal[True, None] = None, - start_date: datetime.datetime | None, - end_date: datetime.datetime | None, - run_id: str | None, - state: TaskInstanceState | Sequence[TaskInstanceState], - exclude_task_ids: Collection[str | tuple[str, int]] | None, - exclude_run_ids: frozenset[str] | None, - session: Session, - ) -> Iterable[TaskInstance] | set[TaskInstanceKey]: - from airflow.models.taskinstance import TaskInstance - - # If we are looking at dependent dags we want to avoid UNION calls - # in SQL (it doesn't play nice with fields that have no equality operator, - # like JSON types), we instead build our result set separately. - # - # This will be empty if we are only looking at one dag, in which case - # we can return the filtered TI query object directly. - result: set[TaskInstanceKey] = set() - - # Do we want full objects, or just the primary columns? 
- if as_pk_tuple: - tis_pk = select( - TaskInstance.dag_id, - TaskInstance.task_id, - TaskInstance.run_id, - TaskInstance.map_index, - ) - tis_pk = tis_pk.join(TaskInstance.dag_run) - else: - tis_full = select(TaskInstance) - tis_full = tis_full.join(TaskInstance.dag_run) - - # Apply common filters - def apply_filters(query): - if self.partial: - query = query.where( - TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids) - ) - else: - query = query.where(TaskInstance.dag_id == self.dag_id) - if run_id: - query = query.where(TaskInstance.run_id == run_id) - if start_date: - query = query.where(DagRun.logical_date >= start_date) - if task_ids is not None: - # Use the selector condition directly without intermediate variable - query = query.where(TaskInstance.ti_selector_condition(task_ids)) - if end_date: - query = query.where(DagRun.logical_date <= end_date) - return query - - if as_pk_tuple: - tis_pk = apply_filters(tis_pk) - else: - tis_full = apply_filters(tis_full) - - def apply_state_filter(query): - if state: - if isinstance(state, (str, TaskInstanceState)): - query = query.where(TaskInstance.state == state) - elif len(state) == 1: - query = query.where(TaskInstance.state == state[0]) - else: - # this is required to deal with NULL values - if None in state: - if all(x is None for x in state): - query = query.where(TaskInstance.state.is_(None)) - else: - not_none_state = [s for s in state if s] - query = query.where( - or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None)) - ) - else: - query = query.where(TaskInstance.state.in_(state)) - - if exclude_run_ids: - query = query.where(TaskInstance.run_id.not_in(exclude_run_ids)) - return query - - if as_pk_tuple: - tis_pk = apply_state_filter(tis_pk) - else: - tis_full = apply_state_filter(tis_full) - - if result or as_pk_tuple: - # Only execute the `ti` query if we have also collected some other results - if as_pk_tuple: - tis_query = session.execute(tis_pk).all() - result.update(TaskInstanceKey(**cols._mapping) for cols in tis_query) - else: - result.update(ti.key for ti in session.scalars(tis_full)) - - if exclude_task_ids is not None: - result = { - task - for task in result - if task.task_id not in exclude_task_ids - and (task.task_id, task.map_index) not in exclude_task_ids - } - - if as_pk_tuple: - return result - if result: - # We've been asked for objects, lets combine it all back in to a result set - ti_filters = TaskInstance.filter_for_tis(result) - if ti_filters is not None: - tis_final = select(TaskInstance).where(ti_filters) - return session.scalars(tis_final) - elif exclude_task_ids is None: - pass # Disable filter if not set. - elif isinstance(next(iter(exclude_task_ids), None), str): - tis_full = tis_full.where(TaskInstance.task_id.notin_(exclude_task_ids)) - else: - tis_full = tis_full.where( - tuple_(TaskInstance.task_id, TaskInstance.map_index).not_in(exclude_task_ids) - ) - - return session.scalars(tis_full) - - @overload - def clear( - self, - *, - dry_run: Literal[True], - task_ids: Collection[str | tuple[str, int]] | None = None, - run_id: str, - only_failed: bool = False, - only_running: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - session: Session = NEW_SESSION, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> list[TaskInstance]: ... 
# pragma: no cover - - @overload - def clear( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None = None, - run_id: str, - only_failed: bool = False, - only_running: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - dry_run: Literal[False] = False, - session: Session = NEW_SESSION, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> int: ... # pragma: no cover - - @overload - def clear( - self, - *, - dry_run: Literal[True], - task_ids: Collection[str | tuple[str, int]] | None = None, - start_date: datetime.datetime | None = None, - end_date: datetime.datetime | None = None, - only_failed: bool = False, - only_running: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - session: Session = NEW_SESSION, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> list[TaskInstance]: ... # pragma: no cover - - @overload - def clear( - self, - *, - task_ids: Collection[str | tuple[str, int]] | None = None, - start_date: datetime.datetime | None = None, - end_date: datetime.datetime | None = None, - only_failed: bool = False, - only_running: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - dry_run: Literal[False] = False, - session: Session = NEW_SESSION, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> int: ... # pragma: no cover - - @provide_session - def clear( - self, - task_ids: Collection[str | tuple[str, int]] | None = None, - *, - run_id: str | None = None, - start_date: datetime.datetime | None = None, - end_date: datetime.datetime | None = None, - only_failed: bool = False, - only_running: bool = False, - dag_run_state: DagRunState = DagRunState.QUEUED, - dry_run: bool = False, - session: Session = NEW_SESSION, - exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), - exclude_run_ids: frozenset[str] | None = frozenset(), - run_on_latest_version: bool = False, - ) -> int | Iterable[TaskInstance]: - """ - Clear a set of task instances associated with the current dag for a specified date range. - - :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear - :param run_id: The run_id for which the tasks should be cleared - :param start_date: The minimum logical_date to clear - :param end_date: The maximum logical_date to clear - :param only_failed: Only clear failed tasks - :param only_running: Only clear running tasks. - :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not - be changed. - :param dry_run: Find the tasks to clear but don't clear them. 
- :param run_on_latest_version: whether to run on latest serialized DAG and Bundle version - :param session: The sqlalchemy session to use - :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) - tuples that should not be cleared - :param exclude_run_ids: A set of ``run_id`` or (``run_id``) - """ - from airflow.models.taskinstance import clear_task_instances - - state: list[TaskInstanceState] = [] - if only_failed: - state += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED] - if only_running: - # Yes, having `+=` doesn't make sense, but this was the existing behaviour - state += [TaskInstanceState.RUNNING] - - tis_result = self._get_task_instances( - task_ids=task_ids, - start_date=start_date, - end_date=end_date, - run_id=run_id, - state=state, - session=session, - exclude_task_ids=exclude_task_ids, - exclude_run_ids=exclude_run_ids, - ) - - if dry_run: - return list(tis_result) - - tis = list(tis_result) - - count = len(tis) - if count == 0: - return 0 - - clear_task_instances( - list(tis), - session, - dag_run_state=dag_run_state, - run_on_latest_version=run_on_latest_version, - ) - - session.flush() - return count - - @classmethod - def clear_dags( - cls, - dags, - start_date=None, - end_date=None, - only_failed=False, - only_running=False, - dag_run_state=DagRunState.QUEUED, - dry_run=False, - ): - def _coerce_dag(dag): - if isinstance(dag, SerializedDAG): - return dag - return SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) - - if dry_run: - tis = itertools.chain.from_iterable( - _coerce_dag(dag).clear( - start_date=start_date, - end_date=end_date, - only_failed=only_failed, - only_running=only_running, - dag_run_state=dag_run_state, - dry_run=True, - ) - for dag in dags - ) - return list(tis) - - return sum( - _coerce_dag(dag).clear( - start_date=start_date, - end_date=end_date, - only_failed=only_failed, - only_running=only_running, - dag_run_state=dag_run_state, - dry_run=False, - ) - for dag in dags - ) - - def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: - """Return edge information for the given pair of tasks or an empty edge if there is no information.""" - # Note - older serialized dags may not have edge_info being a dict at all - empty = cast("EdgeInfoType", {}) - if self.edge_info: - return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) - return empty - class TaskGroupSerialization(BaseSerialization): """JSON serializable representation of a task group.""" @@ -3816,7 +2725,7 @@ class LazyDeserializedDAG(pydantic.BaseModel): def from_dag(cls, dag: DAG | LazyDeserializedDAG) -> LazyDeserializedDAG: if isinstance(dag, LazyDeserializedDAG): return dag - return cls(data=SerializedDAG.to_dict(dag)) + return cls(data=DagSerialization.to_dict(dag)) @property def hash(self) -> str: @@ -3835,7 +2744,7 @@ def access_control(self) -> Mapping[str, Mapping[str, Collection[str]] | Collect @cached_property def _real_dag(self): try: - return SerializedDAG.from_dict(self.data) + return DagSerialization.from_dict(self.data) except Exception: log.exception("Failed to deserialize DAG") raise diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py index 8d6f665811570..a35f2a4785bc4 100644 --- a/airflow-core/tests/integration/otel/test_otel.py +++ b/airflow-core/tests/integration/otel/test_otel.py @@ -34,11 +34,12 @@ from airflow.models import DAG, DagRun from airflow.models.serialized_dag import SerializedDagModel from 
airflow.models.taskinstance import TaskInstance -from airflow.serialization.serialized_objects import SerializedDAG +from airflow.serialization.definitions.dag import SerializedDAG from airflow.utils.session import create_session from airflow.utils.span_status import SpanStatus from airflow.utils.state import State +from tests_common.test_utils.dag import create_scheduler_dag from tests_common.test_utils.otel_utils import ( assert_parent_children_spans, assert_parent_children_spans_for_non_root, @@ -673,7 +674,7 @@ def serialize_and_get_dags(cls) -> dict[str, SerializedDAG]: SerializedDAG.bulk_write_to_db( bundle_name="testing", bundle_version=None, dags=[dag], session=session ) - dag_dict[dag_id] = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + dag_dict[dag_id] = create_scheduler_dag(dag) else: dag.sync_to_db(session=session) dag_dict[dag_id] = dag diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py index 4abfd6ee8085c..597c1f819874a 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py @@ -1725,7 +1725,7 @@ def test_post_dag_runs_with_empty_payload(self, test_client): }, ] - @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.create_dagrun") + @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.create_dagrun") def test_dagrun_creation_exception_is_handled(self, mock_create_dagrun, test_client): now = timezone.utcnow().isoformat() error_message = "Encountered Error" diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index 778024eebc6bb..c5ac6f419c415 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -3956,7 +3956,7 @@ def test_patch_task_instance_notifies_listeners(self, test_client, session, stat assert response2.json()["state"] == state assert listener.state == listener_state - @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") + @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state") def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): self.create_task_instances(session) @@ -4302,7 +4302,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte ), ], ) - @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") + @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state") def test_update_mask_should_call_mocked_api( self, mock_set_ti_state, @@ -4614,7 +4614,7 @@ def test_set_note_should_respond_200_when_note_is_empty(self, test_client, sessi assert response_ti["note"] == new_note_value _check_task_instance_note(session, response_ti["id"], {"content": new_note_value, "user_id": "test"}) - @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") + @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state") def test_should_raise_409_for_updating_same_task_instance_state( self, mock_set_ti_state, test_client, session ): @@ -4640,7 +4640,7 @@ class 
TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint): RUN_ID = "TEST_DAG_RUN_ID" DAG_DISPLAY_NAME = "example_python_operator" - @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") + @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state") def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): self.create_task_instances(session) @@ -4998,7 +4998,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte ), ], ) - @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") + @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state") def test_update_mask_should_call_mocked_api( self, mock_set_ti_state, @@ -5031,7 +5031,7 @@ def test_update_mask_should_call_mocked_api( assert response.json() == expected_json assert mock_set_ti_state.call_count == set_ti_state_call_count - @mock.patch("airflow.serialization.serialized_objects.SerializedDAG.set_task_instance_state") + @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state") def test_should_return_empty_list_for_updating_same_task_instance_state( self, mock_set_ti_state, test_client, session ): diff --git a/airflow-core/tests/unit/cli/commands/test_task_command.py b/airflow-core/tests/unit/cli/commands/test_task_command.py index 1592e9d7cdac3..0e32eb40be829 100644 --- a/airflow-core/tests/unit/cli/commands/test_task_command.py +++ b/airflow-core/tests/unit/cli/commands/test_task_command.py @@ -43,7 +43,7 @@ from airflow.models.dagbag import DBDagBag from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.bash import BashOperator -from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG +from airflow.serialization.serialized_objects import DagSerialization, LazyDeserializedDAG from airflow.utils.session import create_session from airflow.utils.state import State, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -354,7 +354,7 @@ def test_task_states_for_dag_run(self): SerializedDagModel.write_dag(lazy_deserialized_dag2, bundle_name="testing") task2 = dag2.get_task(task_id="print_the_context") - dag2 = SerializedDAG.from_dict(lazy_deserialized_dag2.data) + dag2 = DagSerialization.from_dict(lazy_deserialized_dag2.data) default_date2 = timezone.datetime(2016, 1, 9) dag2.clear() diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py index d4190de99095a..4da728294ac5e 100644 --- a/airflow-core/tests/unit/dag_processing/test_collection.py +++ b/airflow-core/tests/unit/dag_processing/test_collection.py @@ -391,7 +391,7 @@ def _sync_to_db(): serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() @patch.object(SerializedDagModel, "write_dag") - @patch("airflow.serialization.serialized_objects.SerializedDAG.bulk_write_to_db") + @patch("airflow.serialization.definitions.dag.SerializedDAG.bulk_write_to_db") def test_sync_to_db_is_retried( self, mock_bulk_write_to_db, mock_s10n_write_dag, testing_dag_bundle, session ): diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py index 4d8063f841982..d387480550201 100644 --- a/airflow-core/tests/unit/models/test_cleartasks.py +++ b/airflow-core/tests/unit/models/test_cleartasks.py @@ -606,7 +606,7 @@ def 
test_dags_clear(self, dag_maker, session): ) ti = dr.task_instances[0] ti.task = task - dags.append(dag_maker.dag) + dags.append(dag_maker.serialized_model.dag) tis.append(ti) # test clear all dags diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index 33b35d954bbb8..caff1dff704b1 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -72,8 +72,9 @@ from airflow.sdk.definitions.callback import AsyncCallback from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference from airflow.sdk.definitions.param import Param +from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.encoders import coerce_to_core_timetable -from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG +from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.task.trigger_rule import TriggerRule from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( @@ -391,10 +392,8 @@ def test_dagtag_repr(self, testing_dag_bundle): def test_bulk_write_to_db(self, testing_dag_bundle): clear_db_dags() dags = [ - SerializedDAG.deserialize_dag( - SerializedDAG.serialize_dag( - DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) - ) + create_scheduler_dag( + DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) ) for i in range(4) ] @@ -478,10 +477,8 @@ def test_bulk_write_to_db_single_dag(self, testing_dag_bundle): """ clear_db_dags() dags = [ - SerializedDAG.deserialize_dag( - SerializedDAG.serialize_dag( - DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) - ) + create_scheduler_dag( + DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) ) for i in range(1) ] @@ -509,10 +506,8 @@ def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle): """ clear_db_dags() dags = [ - SerializedDAG.deserialize_dag( - SerializedDAG.serialize_dag( - DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) - ) + create_scheduler_dag( + DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) ) for i in range(4) ] @@ -544,14 +539,8 @@ def test_bulk_write_to_db_interval_save_runtime(self, testing_dag_bundle, interv mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags) with mock.patch.object(DagRun, "active_runs_of_dags", mock_active_runs_of_dags): dags_null_timetable = [ - SerializedDAG.deserialize_dag( - SerializedDAG.serialize_dag(DAG("dag-interval-None", schedule=None, start_date=TEST_DATE)) - ), - SerializedDAG.deserialize_dag( - SerializedDAG.serialize_dag( - DAG("dag-interval-test", schedule=interval, start_date=TEST_DATE) - ) - ), + create_scheduler_dag(DAG("dag-interval-None", schedule=None, start_date=TEST_DATE)), + create_scheduler_dag(DAG("dag-interval-test", schedule=interval, start_date=TEST_DATE)), ] SerializedDAG.bulk_write_to_db("testing", None, dags_null_timetable) if interval: @@ -690,7 +679,7 @@ def test_bulk_write_to_db_assets(self, testing_dag_bundle): ) session = settings.Session() - SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag1)).clear() + create_scheduler_dag(dag1).clear() SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() stored_assets = {x.uri: x for x in 
session.query(AssetModel).all()} @@ -2756,12 +2745,12 @@ def _get_registered_timetable(s): assert caplog.record_tuples == [ ( - "airflow.serialization.serialized_objects", + "airflow.serialization.definitions.dag", logging.ERROR, f"Failed to fetch run info after data interval {DataInterval(start, end)} for DAG {dag.dag_id!r}", ), ] - assert caplog.entries[0].get("exc_info") is not None, "should contain exception context" + assert caplog.entries[0].get("exception"), "should contain exception context" @pytest.mark.parametrize( @@ -2940,7 +2929,7 @@ def test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagR f"A manual DAG run cannot use ID {run_id!r} since it is reserved for {run_id_type.value} runs" ), ): - SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)).create_dagrun( + create_scheduler_dag(dag).create_dagrun( run_type=DagRunType.MANUAL, run_id=run_id, logical_date=DEFAULT_DATE, diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py index 2472c0720461e..54b4d11de3938 100644 --- a/airflow-core/tests/unit/models/test_serialized_dag.py +++ b/airflow-core/tests/unit/models/test_serialized_dag.py @@ -37,7 +37,8 @@ from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk import DAG, Asset, AssetAlias, task as task_decorator from airflow.serialization.dag_dependency import DagDependency -from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG +from airflow.serialization.definitions.dag import SerializedDAG +from airflow.serialization.serialized_objects import DagSerialization, LazyDeserializedDAG from airflow.settings import json from airflow.utils.hashlib_wrapper import md5 from airflow.utils.session import create_session @@ -45,7 +46,7 @@ from airflow.utils.types import DagRunTriggeredByType, DagRunType from tests_common.test_utils import db -from tests_common.test_utils.dag import sync_dag_to_db +from tests_common.test_utils.dag import create_scheduler_dag, sync_dag_to_db logger = logging.getLogger(__name__) @@ -65,7 +66,7 @@ def make_example_dags(module): dagbag = DagBag(module.__path__[0]) - dags = [LazyDeserializedDAG(data=SerializedDAG.to_dict(dag)) for dag in dagbag.dags.values()] + dags = [LazyDeserializedDAG(data=DagSerialization.to_dict(dag)) for dag in dagbag.dags.values()] SerializedDAG.bulk_write_to_db("testing", None, dags) return dagbag.dags @@ -105,7 +106,7 @@ def test_write_dag(self, testing_dag_bundle): assert result.dag_version.dag_code.fileloc == dag.fileloc # Verifies JSON schema. 
- SerializedDAG.validate_schema(result.data) + DagSerialization.validate_schema(result.data) def test_write_dag_when_python_callable_name_changes(self, dag_maker, session): def my_callable(): @@ -201,7 +202,7 @@ def test_read_all_dags_only_picks_the_latest_serdags(self, session): assert len(example_dags) == len(serialized_dags) dag = example_dags.get("example_bash_operator") - SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag=dag)).create_dagrun( + create_scheduler_dag(dag=dag).create_dagrun( run_id="test1", run_after=pendulum.datetime(2025, 1, 1, tz="UTC"), state=DagRunState.QUEUED, @@ -263,13 +264,13 @@ def test_order_of_deps_is_consistent(self, session): outlets=[Asset(uri="test://asset0", name="0*"), Asset(uri="test://asset6", name="6*")], bash_command="sleep 5", ) - deps_order = [x["label"] for x in SerializedDAG.serialize_dag(dag6)["dag_dependencies"]] + deps_order = [x["label"] for x in DagSerialization.serialize_dag(dag6)["dag_dependencies"]] # in below assert, 0 and 6 both come at end because "source" is different for them and source # is the first field in DagDependency class assert deps_order == ["1", "2", "3", "4", "5", "0*", "6*"] # for good measure, let's check that the dag hash is consistent - dag_json = json.dumps(SerializedDAG.to_dict(dag6), sort_keys=True).encode("utf-8") + dag_json = json.dumps(DagSerialization.to_dict(dag6), sort_keys=True).encode("utf-8") this_dag_hash = md5(dag_json).hexdigest() # set first dag hash on first pass @@ -568,7 +569,7 @@ def __init__(self, *, task_id: str, **kwargs): with dag_maker("test_dag") as dag1: MyCustomOp(task_id="task1") - serialized_dag_1 = SerializedDAG.to_dict(dag1) + serialized_dag_1 = DagSerialization.to_dict(dag1) # Create second DAG with env_vars in different order with dag_maker("test_dag") as dag2: @@ -576,7 +577,7 @@ def __init__(self, *, task_id: str, **kwargs): # Recreate dict with different insertion order task.env_vars = {"KEY3": "value3", "KEY1": "value1", "KEY2": "value2"} - serialized_dag_2 = SerializedDAG.to_dict(dag2) + serialized_dag_2 = DagSerialization.to_dict(dag2) # Verify that the original env_vars have different ordering env_vars_1 = None diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 81f7f398ed38b..73d45db5ec6e6 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -69,14 +69,15 @@ from airflow.sdk.definitions.param import Param, ParamsDict from airflow.security import permissions from airflow.serialization.definitions.assets import SerializedAssetUniqueKey +from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.definitions.notset import NOTSET from airflow.serialization.encoders import ensure_serialized_asset from airflow.serialization.enums import Encoding from airflow.serialization.json_schema import load_dag_schema_dict from airflow.serialization.serialized_objects import ( BaseSerialization, + DagSerialization, SerializedBaseOperator, - SerializedDAG, SerializedParam, XComOperatorLink, ) @@ -459,7 +460,7 @@ def serialize_subprocess(queue, dag_folder): """Validate pickle in a subprocess.""" dags, _ = collect_dags(dag_folder) for dag in dags.values(): - queue.put(SerializedDAG.to_json(dag)) + queue.put(DagSerialization.to_json(dag)) queue.put(None) @@ -512,8 +513,8 @@ def test_serialization(self): dags, import_errors = collect_dags() serialized_dags = {} for v 
in dags.values(): - dag = SerializedDAG.to_dict(v) - SerializedDAG.validate_schema(dag) + dag = DagSerialization.to_dict(v) + DagSerialization.validate_schema(dag) serialized_dags[v.dag_id] = dag # Ignore some errors. @@ -563,8 +564,8 @@ def test_serialization(self): def test_dag_serialization_to_timetable(self, timetable, serialized_timetable): """Verify a timetable-backed DAG is serialized correctly.""" dag = get_timetable_based_simple_dag(timetable) - serialized_dag = SerializedDAG.to_dict(dag) - SerializedDAG.validate_schema(serialized_dag) + serialized_dag = DagSerialization.to_dict(dag) + DagSerialization.validate_schema(serialized_dag) expected = copy.deepcopy(serialized_simple_dag_ground_truth) expected["dag"]["timetable"] = serialized_timetable @@ -584,8 +585,8 @@ def test_dag_serialization_preserves_empty_access_roles(self): """Verify that an explicitly empty access_control dict is preserved.""" dag = make_simple_dag() dag.access_control = {} - serialized_dag = SerializedDAG.to_dict(dag) - SerializedDAG.validate_schema(serialized_dag) + serialized_dag = DagSerialization.to_dict(dag) + DagSerialization.validate_schema(serialized_dag) assert serialized_dag["dag"]["access_control"] == { "__type": "dict", @@ -597,7 +598,7 @@ def test_dag_serialization_unregistered_custom_timetable(self): """Verify serialization fails without timetable registration.""" dag = get_timetable_based_simple_dag(CustomSerializationTimetable("bar")) with pytest.raises(SerializationError) as ctx: - SerializedDAG.to_dict(dag) + DagSerialization.to_dict(dag) message = ( "Failed to serialize DAG 'simple_dag': Timetable class " @@ -655,7 +656,7 @@ def test_deserialization_across_process(self): v = queue.get() if v is None: break - dag = SerializedDAG.from_json(v) + dag = DagSerialization.from_json(v) assert isinstance(dag, SerializedDAG) stringified_dags[dag.dag_id] = dag @@ -678,7 +679,7 @@ def test_roundtrip_provider_example_dags(self): # Verify deserialized DAGs. 
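# A minimal round-trip sketch in the spirit of the tests above; the dag and task ids
# are placeholders, and the EmptyOperator import path is assumed from the standard
# provider. DagSerialization.to_dict produces the JSON payload, validate_schema checks
# it against the DAG schema, and from_dict rebuilds it as the scheduler-side SerializedDAG.
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import DAG
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization

with DAG("roundtrip_example", schedule=None) as sdk_dag:
    EmptyOperator(task_id="noop")

data = DagSerialization.to_dict(sdk_dag)      # {"__version": ..., "dag": {...}}
DagSerialization.validate_schema(data)        # raises if the payload drifts from the schema
scheduler_dag = DagSerialization.from_dict(data)

assert isinstance(scheduler_dag, SerializedDAG)
assert scheduler_dag.get_task("noop").task_id == "noop"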
for dag in dags.values(): - serialized_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + serialized_dag = DagSerialization.from_json(DagSerialization.to_json(dag)) self.validate_deserialized_dag(serialized_dag, dag) # Let's not be exact about this, but if everything fails to parse we should fail this test too @@ -693,7 +694,7 @@ def test_roundtrip_provider_example_dags(self): def test_dag_roundtrip_from_timetable(self, timetable): """Verify a timetable-backed serialization can be deserialized.""" dag = get_timetable_based_simple_dag(timetable) - roundtripped = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + roundtripped = DagSerialization.from_json(DagSerialization.to_json(dag)) self.validate_deserialized_dag(roundtripped, dag) def validate_deserialized_dag(self, serialized_dag: SerializedDAG, dag: DAG): @@ -902,7 +903,7 @@ def test_deserialization_start_date(self, dag_start_date, task_start_date, expec dag = DAG(dag_id="simple_dag", schedule=None, start_date=dag_start_date) BaseOperator(task_id="simple_task", dag=dag, start_date=task_start_date) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = DagSerialization.to_dict(dag) if not task_start_date or dag_start_date >= task_start_date: # If dag.start_date > task.start_date -> task.start_date=dag.start_date # because of the logic in dag.add_task() @@ -910,7 +911,7 @@ def test_deserialization_start_date(self, dag_start_date, task_start_date, expec else: assert "start_date" in serialized_dag["dag"]["tasks"][0]["__var"] - dag = SerializedDAG.from_dict(serialized_dag) + dag = DagSerialization.from_dict(serialized_dag) simple_task = dag.task_dict["simple_task"] assert simple_task.start_date == expected_task_start_date @@ -922,7 +923,7 @@ def test_deserialization_with_dag_context(self): ) as dag: BaseOperator(task_id="simple_task") # should not raise RuntimeError: dictionary changed size during iteration - SerializedDAG.to_dict(dag) + DagSerialization.to_dict(dag) @pytest.mark.parametrize( ("dag_end_date", "task_end_date", "expected_task_end_date"), @@ -953,7 +954,7 @@ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_ta ) BaseOperator(task_id="simple_task", dag=dag, end_date=task_end_date) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = DagSerialization.to_dict(dag) if not task_end_date or dag_end_date <= task_end_date: # If dag.end_date < task.end_date -> task.end_date=dag.end_date # because of the logic in dag.add_task() @@ -961,7 +962,7 @@ def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_ta else: assert "end_date" in serialized_dag["dag"]["tasks"][0]["__var"] - dag = SerializedDAG.from_dict(serialized_dag) + dag = DagSerialization.from_dict(serialized_dag) simple_task = dag.task_dict["simple_task"] assert simple_task.end_date == expected_task_end_date @@ -1010,8 +1011,8 @@ def test_deserialization_timetable( "timetable": serialized_timetable, }, } - SerializedDAG.validate_schema(serialized) - dag = SerializedDAG.from_dict(serialized) + DagSerialization.validate_schema(serialized) + dag = DagSerialization.from_dict(serialized) assert dag.timetable == expected_timetable @pytest.mark.parametrize( @@ -1059,8 +1060,8 @@ def test_deserialization_timetable_summary( "timetable": serialized_timetable, }, } - SerializedDAG.validate_schema(serialized) - dag = SerializedDAG.from_dict(serialized) + DagSerialization.validate_schema(serialized) + dag = DagSerialization.from_dict(serialized) assert dag.timetable.summary == expected_timetable_summary def 
test_deserialization_timetable_unregistered(self): @@ -1075,7 +1076,7 @@ def test_deserialization_timetable_unregistered(self): "timetable": CUSTOM_TIMETABLE_SERIALIZED, }, } - SerializedDAG.validate_schema(serialized) + DagSerialization.validate_schema(serialized) message = ( "Timetable class " "'tests_common.test_utils.timetables.CustomSerializationTimetable' " @@ -1084,7 +1085,7 @@ def test_deserialization_timetable_unregistered(self): "Please check the airflow best practices documentation." ) with pytest.raises(ValueError, match=message): - SerializedDAG.from_dict(serialized) + DagSerialization.from_dict(serialized) @pytest.mark.parametrize( ("val", "expected"), @@ -1110,10 +1111,10 @@ def test_deserialization_timetable_unregistered(self): ], ) def test_roundtrip_relativedelta(self, val, expected): - serialized = SerializedDAG.serialize(val) + serialized = DagSerialization.serialize(val) assert serialized == expected - round_tripped = SerializedDAG.deserialize(serialized) + round_tripped = DagSerialization.deserialize(serialized) assert val == round_tripped @pytest.mark.parametrize( @@ -1136,13 +1137,13 @@ def test_dag_params_roundtrip(self, val, expected_val): dag = DAG(dag_id="simple_dag", schedule=None, params=val) BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1)) - serialized_dag_json = SerializedDAG.to_json(dag) + serialized_dag_json = DagSerialization.to_json(dag) serialized_dag = json.loads(serialized_dag_json) assert "params" in serialized_dag["dag"] - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) deserialized_simple_task = deserialized_dag.task_dict["simple_task"] assert expected_val == deserialized_dag.params.dump() assert expected_val == deserialized_simple_task.params.dump() @@ -1164,7 +1165,7 @@ def __init__(self, path: str): ) with pytest.raises(SerializationError): - SerializedDAG.to_dict(dag) + DagSerialization.to_dict(dag) dag = DAG(dag_id="simple_dag", schedule=None) BaseOperator( @@ -1189,10 +1190,10 @@ def test_full_param_roundtrip(self, param: Param): Test to make sure that only native Param objects are being passed as dag or task params """ sdk_dag = DAG(dag_id="simple_dag", schedule=None, params={"my_param": param}) - serialized_json = SerializedDAG.to_json(sdk_dag) + serialized_json = DagSerialization.to_json(sdk_dag) serialized = json.loads(serialized_json) - SerializedDAG.validate_schema(serialized) - dag = SerializedDAG.from_dict(serialized) + DagSerialization.validate_schema(serialized) + dag = DagSerialization.from_dict(serialized) assert dag.params.get_param("my_param").value == param.value observed_param = dag.params.get_param("my_param") @@ -1235,8 +1236,8 @@ def test_task_params_roundtrip(self, val, expected_val): params=val, start_date=datetime(2019, 8, 1), ) - serialized_dag = SerializedDAG.to_dict(dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + serialized_dag = DagSerialization.to_dict(dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) if val: assert "params" in serialized_dag["dag"]["tasks"][0]["__var"] @@ -1290,10 +1291,10 @@ def test_extra_serialized_field_and_operator_links( with dag_maker(dag_id="simple_dag", start_date=test_date) as dag: CustomOperator(task_id="simple_task", bash_command=bash_command) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = DagSerialization.to_dict(dag) assert "bash_command" in serialized_dag["dag"]["tasks"][0]["__var"] - dag = 
SerializedDAG.from_dict(serialized_dag) + dag = DagSerialization.from_dict(serialized_dag) simple_task = dag.task_dict["simple_task"] assert getattr(simple_task, "bash_command") == bash_command @@ -1435,8 +1436,8 @@ def test_templated_fields_exist_in_serialized_dag(self, templated_field, expecte with dag: BashOperator(task_id="test", bash_command=templated_field) - serialized_dag = SerializedDAG.to_dict(dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + serialized_dag = DagSerialization.to_dict(dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) deserialized_test_task = deserialized_dag.task_dict["test"] assert expected_field == getattr(deserialized_test_task, "bash_command") @@ -1587,7 +1588,10 @@ def test_operator_deserialize_old_names(self): "ui_fgcolor": "#000", } - SerializedDAG._json_schema.validate(blob, _schema=load_dag_schema_dict()["definitions"]["operator"]) + DagSerialization._json_schema.validate( + blob, + _schema=load_dag_schema_dict()["definitions"]["operator"], + ) serialized_op = SerializedBaseOperator.deserialize_operator(blob) assert serialized_op.downstream_task_ids == {"foo"} @@ -1602,9 +1606,9 @@ def test_task_resources(self): with DAG("test_task_resources", schedule=None, start_date=logical_date) as dag: task = EmptyOperator(task_id=task_id, resources={"cpus": 0.1, "ram": 2048}) - SerializedDAG.validate_schema(SerializedDAG.to_dict(dag)) + DagSerialization.validate_schema(DagSerialization.to_dict(dag)) - json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + json_dag = DagSerialization.from_json(DagSerialization.to_json(dag)) deserialized_task = json_dag.get_task(task_id) assert deserialized_task.resources == task.resources assert isinstance(deserialized_task.resources, Resources) @@ -1629,12 +1633,12 @@ def test_task_group_serialization(self): task1 >> group234 group34 >> task5 - dag_dict = SerializedDAG.to_dict(dag) - SerializedDAG.validate_schema(dag_dict) - json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + dag_dict = DagSerialization.to_dict(dag) + DagSerialization.validate_schema(dag_dict) + json_dag = DagSerialization.from_json(DagSerialization.to_json(dag)) self.validate_deserialized_dag(json_dag, dag) - serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + serialized_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(dag)) assert serialized_dag.task_group.children assert serialized_dag.task_group.children.keys() == dag.task_group.children.keys() @@ -1691,12 +1695,12 @@ def test_setup_teardown_tasks(self): EmptyOperator(task_id="task2") EmptyOperator(task_id="teardown2").as_teardown() - dag_dict = SerializedDAG.to_dict(dag) - SerializedDAG.validate_schema(dag_dict) - json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + dag_dict = DagSerialization.to_dict(dag) + DagSerialization.validate_schema(dag_dict) + json_dag = DagSerialization.from_json(DagSerialization.to_json(dag)) self.validate_deserialized_dag(json_dag, dag) - serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + serialized_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(dag)) self.assert_taskgroup_children( serialized_dag.task_group, dag.task_group, {"setup", "teardown", "group1"} @@ -1738,12 +1742,12 @@ def mytask(): mytask() - dag_dict = SerializedDAG.to_dict(dag) - SerializedDAG.validate_schema(dag_dict) - json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + dag_dict = DagSerialization.to_dict(dag) + 
DagSerialization.validate_schema(dag_dict) + json_dag = DagSerialization.from_json(DagSerialization.to_json(dag)) self.validate_deserialized_dag(json_dag, dag) - serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + serialized_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(dag)) task = serialized_dag.task_group.children["mytask"] assert task.is_teardown is True assert task.on_failure_fail_dagrun is True @@ -1755,10 +1759,10 @@ def test_basic_mapped_dag(self, dag_maker): ) assert not dagbag.import_errors dag = dagbag.dags["example_dynamic_task_mapping"] - ser_dag = SerializedDAG.to_dict(dag) + ser_dag = DagSerialization.to_dict(dag) # We should not include `_is_sensor` most of the time (as it would be wasteful). Check we don't assert "_is_sensor" not in ser_dag["dag"]["tasks"][0]["__var"] - SerializedDAG.validate_schema(ser_dag) + DagSerialization.validate_schema(ser_dag) @pytest.mark.db_test def test_teardown_mapped_serialization(self, dag_maker): @@ -1774,12 +1778,12 @@ def mytask(val=None): assert task.partial_kwargs["is_teardown"] is True assert task.partial_kwargs["on_failure_fail_dagrun"] is True - dag_dict = SerializedDAG.to_dict(dag) - SerializedDAG.validate_schema(dag_dict) - json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + dag_dict = DagSerialization.to_dict(dag) + DagSerialization.validate_schema(dag_dict) + json_dag = DagSerialization.from_json(DagSerialization.to_json(dag)) self.validate_deserialized_dag(json_dag, dag) - serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + serialized_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(dag)) task = serialized_dag.task_group.children["mytask"] assert task.partial_kwargs["is_teardown"] is True assert task.partial_kwargs["on_failure_fail_dagrun"] is True @@ -1831,7 +1835,7 @@ class DerivedSensor(ExternalTaskSensor): task2 = EmptyOperator(task_id="task2") task1 >> task2 - dag = SerializedDAG.to_dict(dag) + dag = DagSerialization.to_dict(dag) assert dag["dag"]["dag_dependencies"] == [ { "source": "external_dag_id", @@ -1873,7 +1877,7 @@ def other_asset_writer(x): for asset in testing_assets ] - dag = SerializedDAG.to_dict(dag) + dag = DagSerialization.to_dict(dag) actual = sorted(dag["dag"]["dag_dependencies"], key=lambda x: tuple(x.values())) expected = sorted( [ @@ -1975,7 +1979,7 @@ def other_asset_writer(x): for asset in testing_assets ] - dag = SerializedDAG.to_dict(dag) + dag = DagSerialization.to_dict(dag) actual = sorted(dag["dag"]["dag_dependencies"], key=lambda x: tuple(x.values())) expected = sorted( [ @@ -2052,7 +2056,7 @@ class DerivedOperator(TriggerDagRunOperator): ) task1 >> task2 - dag = SerializedDAG.to_dict(dag) + dag = DagSerialization.to_dict(dag) assert dag["dag"]["dag_dependencies"] == [ { "source": "test_derived_dag_deps_trigger", @@ -2153,12 +2157,12 @@ def test_edge_info_serialization(self): task2 = EmptyOperator(task_id="task2") task1 >> Label("test label") >> task2 - dag_dict = SerializedDAG.to_dict(dag) - SerializedDAG.validate_schema(dag_dict) - json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + dag_dict = DagSerialization.to_dict(dag) + DagSerialization.validate_schema(dag_dict) + json_dag = DagSerialization.from_json(DagSerialization.to_json(dag)) self.validate_deserialized_dag(json_dag, dag) - serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + serialized_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(dag)) assert 
serialized_dag.edge_info == dag.edge_info @@ -2220,13 +2224,13 @@ def test_dag_on_success_callback_roundtrip(self, passed_success_callback, expect ) BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1)) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = DagSerialization.to_dict(dag) if expected_value: assert "has_on_success_callback" in serialized_dag["dag"] else: assert "has_on_success_callback" not in serialized_dag["dag"] - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) assert deserialized_dag.has_on_success_callback is expected_value @@ -2252,13 +2256,13 @@ def test_dag_on_failure_callback_roundtrip(self, passed_failure_callback, expect ) BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1)) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = DagSerialization.to_dict(dag) if expected_value: assert "has_on_failure_callback" in serialized_dag["dag"] else: assert "has_on_failure_callback" not in serialized_dag["dag"] - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) assert deserialized_dag.has_on_failure_callback is expected_value @@ -2291,8 +2295,8 @@ def test_dag_disable_bundle_versioning_roundtrip(self, dag_arg, conf_arg, expect **kwargs, ) BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1)) - serialized_dag = SerializedDAG.to_dict(dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + serialized_dag = DagSerialization.to_dict(dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) assert deserialized_dag.disable_bundle_versioning is expected @pytest.mark.parametrize( @@ -2341,7 +2345,7 @@ def test_dag_disable_bundle_versioning_roundtrip(self, dag_arg, conf_arg, expect ) def test_serialized_objects_are_sorted(self, object_to_serialized, expected_output): """Test Serialized Sets are sorted while list and tuple preserve order""" - serialized_obj = SerializedDAG.serialize(object_to_serialized) + serialized_obj = DagSerialization.serialize(object_to_serialized) if isinstance(serialized_obj, dict) and "__type" in serialized_obj: serialized_obj = serialized_obj["__var"] assert serialized_obj == expected_output @@ -2358,7 +2362,7 @@ def test_params_upgrade(self): "params": {"none": None, "str": "str", "dict": {"a": "b"}}, }, } - dag = SerializedDAG.from_dict(serialized) + dag = DagSerialization.from_dict(serialized) assert dag.params["none"] is None # After decoupling, server-side deserialization uses SerializedParam @@ -2385,7 +2389,7 @@ def test_params_serialization_from_dict_upgrade(self): }, }, } - dag = SerializedDAG.from_dict(serialized) + dag = DagSerialization.from_dict(serialized) param = dag.params.get_param("my_param") # After decoupling, server-side deserialization uses SerializedParam @@ -2408,8 +2412,8 @@ def test_params_serialize_default_2_2_0(self): "params": [["str", {"__class": "airflow.models.param.Param", "default": "str"}]], }, } - SerializedDAG.validate_schema(serialized) - dag = SerializedDAG.from_dict(serialized) + DagSerialization.validate_schema(serialized) + dag = DagSerialization.from_dict(serialized) # After decoupling, server-side deserialization uses SerializedParam assert isinstance(dag.params.get_param("str"), SerializedParam) @@ -2436,8 +2440,8 @@ def test_params_serialize_default(self): ], }, } - SerializedDAG.validate_schema(serialized) - dag = SerializedDAG.from_dict(serialized) + 
DagSerialization.validate_schema(serialized) + dag = DagSerialization.from_dict(serialized) assert dag.params["my_param"] == "a string value" param = dag.params.get_param("my_param") @@ -2481,7 +2485,7 @@ def execute(self, context: Context): ) ), ): - SerializedDAG.to_dict(dag) + DagSerialization.to_dict(dag) @pytest.mark.db_test def test_start_trigger_args_in_serialized_dag(self): @@ -2529,7 +2533,7 @@ def execute_complete(self): TestOperator(task_id="test_task_1") Test2Operator(task_id="test_task_2") - serialized_obj = SerializedDAG.to_dict(dag) + serialized_obj = DagSerialization.to_dict(dag) tasks = serialized_obj["dag"]["tasks"] assert tasks[0]["__var"]["start_trigger_args"] == { @@ -2590,7 +2594,7 @@ def mock__import__(name, globals_=None, locals_=None, fromlist=(), level=0): module.BaseSerialization.from_dict(pod_override) # basic serialization should succeed - module.SerializedDAG.to_dict(make_simple_dag()) + module.DagSerialization.to_dict(make_simple_dag()) def test_operator_expand_serde(): @@ -2681,7 +2685,7 @@ def test_operator_expand_xcomarg_serde(): xcom_ref = op.expand_input.value["arg2"] assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}) - serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + serialized_dag: DAG = DagSerialization.from_dict(DagSerialization.to_dict(dag)) xcom_arg = serialized_dag.task_dict["task_2"].expand_input.value["arg2"] assert isinstance(xcom_arg, SchedulerPlainXComArg) @@ -2740,7 +2744,7 @@ def test_operator_expand_kwargs_literal_serde(strict): {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}, ] - serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + serialized_dag: DAG = DagSerialization.from_dict(DagSerialization.to_dict(dag)) resolved_expand_value = serialized_dag.task_dict["task_2"].expand_input.value assert resolved_expand_value == [ @@ -2787,7 +2791,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): xcom_ref = op.expand_input.value assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}) - serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + serialized_dag: DAG = DagSerialization.from_dict(DagSerialization.to_dict(dag)) xcom_arg = serialized_dag.task_dict["task_2"].expand_input.value assert isinstance(xcom_arg, SchedulerPlainXComArg) @@ -3059,7 +3063,7 @@ def tg(a: str) -> None: }, ) - serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR]) + serde_dag = DagSerialization.deserialize_dag(ser_dag[Encoding.VAR]) serde_tg = serde_dag.task_group.children["tg"] assert isinstance(serde_tg, SerializedTaskGroup) assert serde_tg._expand_input == SchedulerDictOfListsExpandInput({"a": [".", ".."]}) @@ -3096,7 +3100,7 @@ def operator_extra_links(self): "_task_module": "unit.serialization.test_dag_serialization", "_is_mapped": True, } - deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR]) + deserialized_dag = DagSerialization.deserialize_dag(serialized_dag[Encoding.VAR]) # operator defined links have to be instances of XComOperatorLink assert deserialized_dag.task_dict["task"].operator_extra_links == [ XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2") @@ -3408,13 +3412,13 @@ def test_handle_v1_serdag(): }, ] - SerializedDAG.conversion_v1_to_v2(v1) - SerializedDAG.conversion_v2_to_v3(v1) + DagSerialization.conversion_v1_to_v2(v1) + DagSerialization.conversion_v2_to_v3(v1) - dag = SerializedDAG.from_dict(v1) + dag = DagSerialization.from_dict(v1) expected_sdag = 
copy.deepcopy(serialized_simple_dag_ground_truth) - expected = SerializedDAG.from_dict(expected_sdag) + expected = DagSerialization.from_dict(expected_sdag) fields_to_verify = set(vars(expected).keys()) - { "task_group", # Tested separately @@ -3616,10 +3620,10 @@ def test_handle_v2_serdag(): } # Test that v2 DAGs can be deserialized without conversion - dag = SerializedDAG.from_dict(v2) + dag = DagSerialization.from_dict(v2) expected_sdag = copy.deepcopy(serialized_simple_dag_ground_truth) - expected = SerializedDAG.from_dict(expected_sdag) + expected = DagSerialization.from_dict(expected_sdag) fields_to_verify = set(vars(expected).keys()) - { "task_group", # Tested separately @@ -3664,11 +3668,11 @@ def test_dag_schema_defaults_optimization(): ) # Serialize and check exclusions - serialized = SerializedDAG.to_dict(dag_with_defaults) + serialized = DagSerialization.to_dict(dag_with_defaults) dag_data = serialized["dag"] # Schema default fields should be excluded - for field in SerializedDAG.get_schema_defaults("dag").keys(): + for field in DagSerialization.get_schema_defaults("dag").keys(): assert field not in dag_data, f"Schema default field '{field}' should be excluded" # None fields should also be excluded @@ -3677,7 +3681,7 @@ def test_dag_schema_defaults_optimization(): assert field not in dag_data, f"None field '{field}' should be excluded" # Test deserialization restores defaults correctly - deserialized_dag = SerializedDAG.from_dict(serialized) + deserialized_dag = DagSerialization.from_dict(serialized) # Verify schema defaults are restored assert deserialized_dag.catchup is False @@ -3697,7 +3701,7 @@ def test_dag_schema_defaults_optimization(): description="Test description", # Non-None ) - serialized_non_defaults = SerializedDAG.to_dict(dag_non_defaults) + serialized_non_defaults = DagSerialization.to_dict(dag_non_defaults) dag_non_defaults_data = serialized_non_defaults["dag"] # Non-default values should be included @@ -3720,7 +3724,7 @@ def test_email_optimization_removes_email_attrs_when_email_empty(): email_on_retry=True, # This should be removed during serialization ) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = DagSerialization.to_dict(dag) task_serialized = serialized_dag["dag"]["tasks"][0]["__var"] assert task_serialized is not None @@ -3737,7 +3741,7 @@ def test_email_optimization_removes_email_attrs_when_email_empty(): email_on_retry=True, ) - serialized_dag_with_email = SerializedDAG.to_dict(dag_with_email) + serialized_dag_with_email = DagSerialization.to_dict(dag_with_email) task_with_email_serialized = serialized_dag_with_email["dag"]["tasks"][0]["__var"] assert task_with_email_serialized is not None @@ -3984,10 +3988,10 @@ def test_task(): test_task() - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = DagSerialization.to_dict(dag) assert serialized_dag["dag"]["tasks"][0]["__var"]["weight_rule"] == "absolute" - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) deserialized_task = deserialized_dag.task_dict["test_task"] assert isinstance(deserialized_task.weight_rule, _AbsolutePriorityWeightStrategy) @@ -4132,7 +4136,7 @@ def test_multiple_tasks_share_client_defaults(self, operator_defaults): BashOperator(task_id="task1", bash_command="echo 1") BashOperator(task_id="task2", bash_command="echo 2") - serialized = SerializedDAG.to_dict(dag) + serialized = DagSerialization.to_dict(dag) # Should have one client_defaults section for all tasks assert 
"client_defaults" in serialized @@ -4142,7 +4146,7 @@ def test_multiple_tasks_share_client_defaults(self, operator_defaults): client_defaults = serialized["client_defaults"]["tasks"] # Deserialize and check both tasks get the defaults - deserialized_dag = SerializedDAG.from_dict(serialized) + deserialized_dag = DagSerialization.from_dict(serialized) deserialized_task1 = deserialized_dag.get_task("task1") deserialized_task2 = deserialized_dag.get_task("task2") @@ -4164,7 +4168,7 @@ def test_default_args_when_equal_to_schema_defaults(self, operator_defaults): BashOperator(task_id="task1", bash_command="echo 1") BashOperator(task_id="task2", bash_command="echo 1", retries=2) - serialized = SerializedDAG.to_dict(dag) + serialized = DagSerialization.to_dict(dag) # verify client_defaults has retries=3 assert "client_defaults" in serialized @@ -4178,10 +4182,10 @@ def test_default_args_when_equal_to_schema_defaults(self, operator_defaults): task2_data = serialized["dag"]["tasks"][1]["__var"] assert task2_data.get("retries", -1) == 2 - deserialized_task1 = SerializedDAG.from_dict(serialized).get_task("task1") + deserialized_task1 = DagSerialization.from_dict(serialized).get_task("task1") assert deserialized_task1.retries == 0 - deserialized_task2 = SerializedDAG.from_dict(serialized).get_task("task2") + deserialized_task2 = DagSerialization.from_dict(serialized).get_task("task2") assert deserialized_task2.retries == 2 @@ -4199,14 +4203,14 @@ def test_mapped_operator_client_defaults_application(self, operator_defaults): ).expand(bash_command=["echo 1", "echo 2", "echo 3"]) # Serialize the DAG - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = DagSerialization.to_dict(dag) # Should have client_defaults section assert "client_defaults" in serialized_dag assert "tasks" in serialized_dag["client_defaults"] # Deserialize and check that client_defaults are applied - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) deserialized_task = deserialized_dag.get_task("mapped_task") # Verify it's still a MappedOperator @@ -4265,7 +4269,7 @@ def test_mapped_operator_client_defaults_optimization( **task_config, ).expand(bash_command=["echo 1", "echo 2", "echo 3"]) - serialized_dag = SerializedDAG.to_dict(dag) + serialized_dag = DagSerialization.to_dict(dag) mapped_task_serialized = serialized_dag["dag"]["tasks"][0]["__var"] assert mapped_task_serialized is not None @@ -4380,8 +4384,8 @@ def test_partial_kwargs_end_to_end_deserialization(self): ).expand(bash_command=["echo 1", "echo 2"]) # Serialize and deserialize the DAG - serialized_dag = SerializedDAG.to_dict(dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + serialized_dag = DagSerialization.to_dict(dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) deserialized_task = deserialized_dag.get_task("mapped_task") # Verify the task has correct values after round-trip @@ -4441,7 +4445,7 @@ def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags, with DAG(dag_id="test_default_args_callbacks", default_args=default_args) as dag: BashOperator(task_id="task1", bash_command="echo 1", dag=dag) - serialized_dag_dict = SerializedDAG.serialize_dag(dag) + serialized_dag_dict = DagSerialization.serialize_dag(dag) default_args_dict = serialized_dag_dict["default_args"][Encoding.VAR] for flag in expected_has_flags: @@ -4453,5 +4457,5 @@ def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags, assert 
default_args_dict["owner"] == "test_owner" assert default_args_dict["retries"] == 2 - deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag_dict) + deserialized_dag = DagSerialization.deserialize_dag(serialized_dag_dict) assert deserialized_dag.dag_id == "test_default_args_callbacks" diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index 412d1d283c0b9..0ddc396bd889f 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -81,8 +81,8 @@ from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import ( BaseSerialization, + DagSerialization, LazyDeserializedDAG, - SerializedDAG, _has_kubernetes, create_scheduler_operator, ) @@ -584,7 +584,7 @@ def test_serialized_dag_has_task_concurrency_limits(dag_maker, concurrency_param with dag_maker() as dag: BashOperator(task_id="task1", bash_command="echo 1", **{concurrency_parameter: 1}) - ser_dict = SerializedDAG.to_dict(dag) + ser_dict = DagSerialization.to_dict(dag) lazy_serialized_dag = LazyDeserializedDAG(data=ser_dict) assert lazy_serialized_dag.has_task_concurrency_limits @@ -611,7 +611,7 @@ def map_me_but_slowly(a): map_me_but_slowly.expand(a=my_task()) - ser_dict = SerializedDAG.to_dict(dag) + ser_dict = DagSerialization.to_dict(dag) lazy_serialized_dag = LazyDeserializedDAG(data=ser_dict) assert lazy_serialized_dag.has_task_concurrency_limits diff --git a/airflow-core/tests/unit/timetables/test_assets_timetable.py b/airflow-core/tests/unit/timetables/test_assets_timetable.py index b39be1130f5ab..8eb7ca0f6b7d6 100644 --- a/airflow-core/tests/unit/timetables/test_assets_timetable.py +++ b/airflow-core/tests/unit/timetables/test_assets_timetable.py @@ -26,10 +26,11 @@ from sqlalchemy import select from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel -from airflow.models.serialized_dag import SerializedDAG, SerializedDagModel +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import Asset, AssetAll, AssetAny, AssetOrTimeSchedule as SdkAssetOrTimeSchedule from airflow.serialization.definitions.assets import SerializedAsset, SerializedAssetAll, SerializedAssetAny +from airflow.serialization.serialized_objects import DagSerialization from airflow.timetables.assets import AssetOrTimeSchedule as CoreAssetOrTimeSchedule from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import AssetTriggeredTimetable @@ -308,7 +309,7 @@ def test_asset_dag_run_queue_processing(self, session, dag_maker, create_test_as ) for serialized_dag in serialized_dags: - dag = SerializedDAG.deserialize(serialized_dag.data) + dag = DagSerialization.deserialize(serialized_dag.data) for asset_uri, status in dag_statuses[dag.dag_id].items(): cond = dag.timetable.asset_condition assert evaluator.run(cond, {asset_uri: status}), "DAG trigger evaluation failed" @@ -324,8 +325,8 @@ def test_dag_with_complex_asset_condition(self, dag_maker): assert dag.timetable.asset_condition == AssetAny(asset1, AssetAll(asset2, asset1)) - serialized_triggers = SerializedDAG.serialize(dag.timetable.asset_condition) - deserialized_triggers = SerializedDAG.deserialize(serialized_triggers) + serialized_triggers = 
DagSerialization.serialize(dag.timetable.asset_condition) + deserialized_triggers = DagSerialization.deserialize(serialized_triggers) assert deserialized_triggers == SerializedAssetAny( [ SerializedAsset("hello1", "test://asset1/", "asset", {}, []), @@ -338,7 +339,7 @@ def test_dag_with_complex_asset_condition(self, dag_maker): ], ) - serialized_timetable_dict = SerializedDAG.to_dict(dag)["dag"]["timetable"]["__var"] + serialized_timetable_dict = DagSerialization.to_dict(dag)["dag"]["timetable"]["__var"] assert serialized_timetable_dict == { "asset_condition": { "__type": "asset_any", diff --git a/airflow-core/tests/unit/utils/test_db_cleanup.py b/airflow-core/tests/unit/utils/test_db_cleanup.py index efb8e306a9ee0..e005e476de42e 100644 --- a/airflow-core/tests/unit/utils/test_db_cleanup.py +++ b/airflow-core/tests/unit/utils/test_db_cleanup.py @@ -36,8 +36,9 @@ from airflow.models import DagModel, DagRun, TaskInstance from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel -from airflow.models.serialized_dag import LazyDeserializedDAG, SerializedDagModel +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.python import PythonOperator +from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.utils.db_cleanup import ( ARCHIVE_TABLE_PREFIX, CreateTableAs, diff --git a/airflow-core/tests/unit/utils/test_task_group.py b/airflow-core/tests/unit/utils/test_task_group.py index 67ba538b6cbc8..3b62ad75a7290 100644 --- a/airflow-core/tests/unit/utils/test_task_group.py +++ b/airflow-core/tests/unit/utils/test_task_group.py @@ -33,9 +33,9 @@ task_group as task_group_decorator, teardown, ) -from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.dag_edges import dag_edges +from tests_common.test_utils.dag import create_scheduler_dag from unit.models import DEFAULT_DATE pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] @@ -238,7 +238,7 @@ def test_task_group_to_dict_alternative_syntax(): task1 >> group234 group34 >> task5 - serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + serialized_dag = create_scheduler_dag(dag) assert task_group_to_dict(serialized_dag.task_group) == EXPECTED_JSON @@ -1155,7 +1155,7 @@ def work(): ... 
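For orientation, the task-group and asset tests above now go through the create_scheduler_dag helper instead of hand-rolling a SerializedDAG round trip. A minimal sketch of what that helper does after this change (the real implementation in tests_common.test_utils.dag pulls DagSerialization and SerializedDAG through the compat shim; the direct imports here are for illustration only):

from __future__ import annotations

from airflow.sdk import DAG
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization


def create_scheduler_dag(dag: DAG | SerializedDAG) -> SerializedDAG:
    # Already a scheduler-side DAG: nothing to do.
    if isinstance(dag, SerializedDAG):
        return dag
    # Serialize the SDK DAG with the DagSerialization codec, then rebuild the
    # scheduler-facing SerializedDAG from that payload.
    return DagSerialization.deserialize_dag(DagSerialization.serialize_dag(dag))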
assert set(t2.operator.downstream_task_ids) == set() def get_nodes(group): - serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + serialized_dag = create_scheduler_dag(dag) group = serialized_dag.task_group_dict[g1.group_id] d = task_group_to_dict(group) new_d = {} diff --git a/devel-common/src/tests_common/test_utils/compat.py b/devel-common/src/tests_common/test_utils/compat.py index 718c9f8340fe3..6b3b1b03d353f 100644 --- a/devel-common/src/tests_common/test_utils/compat.py +++ b/devel-common/src/tests_common/test_utils/compat.py @@ -39,6 +39,15 @@ # Compatibility for Airflow 2.7.* from airflow.models.baseoperator import BaseOperatorLink # type: ignore[no-redef] +try: + from airflow.serialization.definitions.dag import SerializedDAG + from airflow.serialization.serialized_objects import DagSerialization +except ImportError: + # Compatibility for Airflow < 3.2.* + from airflow.serialization.serialized_objects import SerializedDAG # type: ignore[no-redef] + + DagSerialization = SerializedDAG # type: ignore[assignment,misc,no-redef] + try: from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer from airflow.providers.standard.operators.bash import BashOperator @@ -63,6 +72,10 @@ # Compatibility for Airflow < 3.1 from airflow.utils.xcom import XCOM_RETURN_KEY # type: ignore[no-redef,attr-defined] +try: + from airflow.sdk import timezone +except ImportError: + from airflow.utils import timezone # type: ignore[no-redef,attr-defined] try: from airflow.sdk import TriggerRule diff --git a/devel-common/src/tests_common/test_utils/dag.py b/devel-common/src/tests_common/test_utils/dag.py index 8d5cd1da069ca..6e02ddf61ab67 100644 --- a/devel-common/src/tests_common/test_utils/dag.py +++ b/devel-common/src/tests_common/test_utils/dag.py @@ -23,19 +23,18 @@ from airflow.utils.session import NEW_SESSION, provide_session +from tests_common.test_utils.compat import DagSerialization, SerializedDAG + if TYPE_CHECKING: from sqlalchemy.orm import Session from airflow.sdk import DAG - from airflow.serialization.serialized_objects import SerializedDAG def create_scheduler_dag(dag: DAG | SerializedDAG) -> SerializedDAG: - from airflow.serialization.serialized_objects import SerializedDAG - if isinstance(dag, SerializedDAG): return dag - return SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + return DagSerialization.deserialize_dag(DagSerialization.serialize_dag(dag)) @provide_session @@ -62,15 +61,15 @@ def sync_dags_to_db( """ from airflow.models.dagbundle import DagBundleModel from airflow.models.serialized_dag import SerializedDagModel - from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG + from airflow.serialization.serialized_objects import LazyDeserializedDAG session.merge(DagBundleModel(name=bundle_name)) session.flush() def _write_dag(dag: DAG) -> SerializedDAG: - data = SerializedDAG.to_dict(dag) + data = DagSerialization.to_dict(dag) SerializedDagModel.write_dag(LazyDeserializedDAG(data=data), bundle_name, session=session) - return SerializedDAG.from_dict(data) + return DagSerialization.from_dict(data) SerializedDAG.bulk_write_to_db(bundle_name, None, dags, session=session) scheduler_dags = [_write_dag(dag) for dag in dags] diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py b/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py index 0a1f713002766..bd3b4d669b3ee 100644 --- a/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py +++ 
b/providers/amazon/tests/unit/amazon/aws/links/test_base_aws.py @@ -24,8 +24,8 @@ from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink from airflow.providers.common.compat.sdk import XCom -from airflow.serialization.serialized_objects import SerializedDAG +from tests_common.test_utils.compat import DagSerialization from tests_common.test_utils.mock_operators import MockOperator if TYPE_CHECKING: @@ -197,7 +197,7 @@ def test_link_serialize(self): """Test: Operator links should exist for serialized DAG.""" self.create_op_and_ti(self.link_class, dag_id="test_link_serialize", task_id=self.task_id) serialized_dag = self.dag_maker.get_serialized_data() - deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) + deserialized_dag = DagSerialization.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] error_message = "Operator links should exist for serialized DAG" assert operator_extra_link.name == self.link_class.name, error_message @@ -209,7 +209,7 @@ def test_empty_xcom(self): ).task_instance serialized_dag = self.dag_maker.get_serialized_data() - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_dag = DagSerialization.from_dict(serialized_dag) deserialized_task = deserialized_dag.task_dict[self.task_id] assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "", ( diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index 64d90fe022a57..9912e4f2f0c06 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -66,9 +66,8 @@ DataprocSubmitTrigger, ) from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME -from airflow.serialization.serialized_objects import SerializedDAG -from airflow.utils.timezone import datetime +from tests_common.test_utils.compat import DagSerialization, timezone from tests_common.test_utils.db import clear_db_runs, clear_db_xcom from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS @@ -347,7 +346,7 @@ "jobs": [{"step_id": "pig_job_1", "pig_job": {}}], } TEST_DAG_ID = "test-dataproc-operators" -DEFAULT_DATE = datetime(2020, 1, 1) +DEFAULT_DATE = timezone.datetime(2020, 1, 1) TEST_JOB_ID = "test-job" TEST_WORKFLOW_ID = "test-workflow" @@ -592,7 +591,7 @@ def test_build(self): service_account="service_account", service_account_scopes=["service_account_scopes"], idle_delete_ttl=60, - auto_delete_time=datetime(2019, 9, 12), + auto_delete_time=timezone.datetime(2019, 9, 12), auto_delete_ttl=250, customer_managed_key="customer_managed_key", driver_pool_id="cluster_driver_pool", @@ -633,7 +632,7 @@ def test_build_with_custom_image_family(self): service_account="service_account", service_account_scopes=["service_account_scopes"], idle_delete_ttl=60, - auto_delete_time=datetime(2019, 9, 12), + auto_delete_time=timezone.datetime(2019, 9, 12), auto_delete_ttl=250, customer_managed_key="customer_managed_key", enable_component_gateway=True, @@ -672,7 +671,7 @@ def test_build_with_flex_migs(self): service_account="service_account", service_account_scopes=["service_account_scopes"], idle_delete_ttl=60, - auto_delete_time=datetime(2019, 9, 12), + auto_delete_time=timezone.datetime(2019, 9, 12), auto_delete_ttl=250, customer_managed_key="customer_managed_key", 
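For the provider test changes in this area, the new import pattern is to go through tests_common.test_utils.compat, which falls back to the legacy SerializedDAG alias on Airflow versions before the split. A hedged sketch of that pattern (the helper function name and its assertion are illustrative; dag_maker stands in for the pytest fixture the real tests receive):

from tests_common.test_utils.compat import DagSerialization, timezone

# tz-aware default date, replacing airflow.utils.timezone.datetime
DEFAULT_DATE = timezone.datetime(2020, 1, 1)


def assert_operator_link_survives_serde(dag_maker, expected_name: str) -> None:
    # Same round trip the Dataproc/AWS link tests rely on: serialize via the
    # dag_maker fixture, deserialize with DagSerialization, then inspect the
    # operator extra links on the scheduler-side task.
    serialized_dag = dag_maker.get_serialized_data()
    deserialized_dag = DagSerialization.deserialize_dag(serialized_dag["dag"])
    operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0]
    assert operator_extra_link.name == expected_name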
secondary_worker_instance_flexibility_policy=InstanceFlexibilityPolicy( @@ -729,7 +728,7 @@ def test_build_with_gpu_accelerator(self): service_account="service_account", service_account_scopes=["service_account_scopes"], idle_delete_ttl=60, - auto_delete_time=datetime(2019, 9, 12), + auto_delete_time=timezone.datetime(2019, 9, 12), auto_delete_ttl=250, customer_managed_key="customer_managed_key", ) @@ -1137,7 +1136,7 @@ def test_create_cluster_operator_extra_links( serialized_dag = dag_maker.get_serialized_data() # Assert operator links for serialized DAG - deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) + deserialized_dag = DagSerialization.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Cluster" @@ -2020,7 +2019,7 @@ def test_submit_job_operator_extra_links( serialized_dag = dag_maker.get_serialized_data() # Assert operator links for serialized DAG - deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) + deserialized_dag = DagSerialization.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Job" @@ -2229,7 +2228,7 @@ def test_update_cluster_operator_extra_links( serialized_dag = dag_maker.get_serialized_data() # Assert operator links for serialized DAG - deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) + deserialized_dag = DagSerialization.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Cluster" @@ -2455,7 +2454,7 @@ def test_instantiate_workflow_operator_extra_links( serialized_dag = dag_maker.get_serialized_data() # Assert operator links for serialized DAG - deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) + deserialized_dag = DagSerialization.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Workflow" @@ -3149,7 +3148,7 @@ def test_instantiate_inline_workflow_operator_extra_links( serialized_dag = dag_maker.get_serialized_data() # Assert operator links for serialized DAG - deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) + deserialized_dag = DagSerialization.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Workflow" if AIRFLOW_V_3_0_PLUS: diff --git a/scripts/in_container/run_schema_defaults_check.py b/scripts/in_container/run_schema_defaults_check.py index eee5b8daba32d..8ad321ac8e2c5 100755 --- a/scripts/in_container/run_schema_defaults_check.py +++ b/scripts/in_container/run_schema_defaults_check.py @@ -101,7 +101,7 @@ def get_server_side_operator_defaults() -> dict[str, Any]: def get_server_side_dag_defaults() -> dict[str, Any]: """Get default values from server-side SerializedDAG class.""" try: - from airflow.serialization.serialized_objects import SerializedDAG + from airflow.serialization.definitions.dag import SerializedDAG # DAG defaults are set in __init__, so we create a temporary instance temp_dag = SerializedDAG(dag_id="temp") diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index bef413af515bf..887493e1ca703 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ 
b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1185,8 +1185,9 @@ def test( from airflow import settings from airflow.models.dagrun import DagRun, get_or_create_dagrun from airflow.sdk import DagRunState, timezone + from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.encoders import coerce_to_core_timetable - from airflow.serialization.serialized_objects import SerializedDAG + from airflow.serialization.serialized_objects import DagSerialization from airflow.utils.types import DagRunTriggeredByType, DagRunType exit_stack = ExitStack() @@ -1214,6 +1215,7 @@ def test( with exit_stack: self.validate() + scheduler_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(self)) # Allow users to explicitly pass None. If it isn't set, we default to current time. logical_date = logical_date if is_arg_set(logical_date) else timezone.utcnow() @@ -1221,7 +1223,7 @@ def test( log.debug("Clearing existing task instances for logical date %s", logical_date) # TODO: Replace with calling client.dag_run.clear in Execution API at some point SerializedDAG.clear_dags( - dags=[self], + dags=[scheduler_dag], start_date=logical_date, end_date=logical_date, dag_run_state=False, @@ -1261,7 +1263,7 @@ def test( version = DagVersion.get_version(self.dag_id) if version: break - scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self)) + # Preserve callback functions from original Dag since they're lost during serialization # and yes it is a hack for now! It is a tradeoff for code simplicity. # Without it, we need "Scheduler Dag" (Serialized dag) for the scheduler bits diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py index 1dafb674f752b..f7e4502e3790a 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset.py @@ -39,7 +39,8 @@ from airflow.sdk.definitions.dag import DAG from airflow.sdk.io import ObjectStoragePath from airflow.serialization.definitions.assets import SerializedAsset, SerializedAssetAny -from airflow.serialization.serialized_objects import SerializedDAG + +from tests_common.test_utils.dag import create_scheduler_dag ASSET_MODULE_PATH = "airflow.sdk.definitions.asset" @@ -237,7 +238,7 @@ def test_asset_trigger_setup_and_serialization(create_test_assets): assert isinstance(dag.timetable.asset_condition, AssetAny), "Dag assets should be an instance of AssetAny" # Round-trip the Dag through serialization - deserialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + deserialized_dag = create_scheduler_dag(dag) # Verify serialization and deserialization integrity assert deserialized_dag.timetable.asset_condition == SerializedAssetAny(
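Tying the task-sdk hunks above together: DAG.test() now builds the scheduler-side DAG once, right after validate(), and passes that object to clear_dags instead of the SDK DAG itself. A condensed sketch of that flow under the same names used in the diff (the wrapper function is illustrative, not part of the change):

from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization


def build_scheduler_dag(sdk_dag) -> SerializedDAG:
    # One serialize/deserialize pass turns the SDK-defined DAG into the
    # scheduler-facing SerializedDAG used for clearing and scheduling.
    return DagSerialization.deserialize_dag(DagSerialization.serialize_dag(sdk_dag))


# Inside DAG.test(), roughly:
#   scheduler_dag = build_scheduler_dag(self)
#   SerializedDAG.clear_dags(
#       dags=[scheduler_dag],
#       start_date=logical_date,
#       end_date=logical_date,
#       dag_run_state=False,
#   )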