From 142ce8824471dccb3b9036b9e78048adbb8cedfc Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Tue, 22 Apr 2025 12:06:09 +0100
Subject: [PATCH 1/4] Make LatestOnlyOperator work for default data-interval-less DAGs

Remove the check/don't skip logic when the data interval is zero-width.
Even if a DAG doesn't have the concept of a data interval (i.e. it is
zero width), it is still logically consistent for it to have the concept
of latest or not, so we now only compare against the end date of the
interval.

(And a few drive-by refactors too: `context["task"]` is `self`,
`context["dag"]` is `self.dag`)
---
 .../standard/operators/latest_only.py         |  72 +++++++-----
 .../operators/test_latest_only_operator.py    | 107 +++++++-----------
 2 files changed, 82 insertions(+), 97 deletions(-)

diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
index fd5cf62bbad2f..ddc5959c41c4b 100644
--- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
@@ -20,6 +20,7 @@
 from __future__ import annotations
 
 from collections.abc import Iterable
+from datetime import timedelta
 from typing import TYPE_CHECKING
 
 import pendulum
@@ -29,8 +30,8 @@
 from airflow.utils.types import DagRunType
 
 if TYPE_CHECKING:
-    from airflow.models import DAG, DagRun
-    from airflow.timetables.base import DagRunInfo
+    from airflow.models import DagRun
+    from airflow.utils.timezone import DateTime
 
 try:
     from airflow.sdk.definitions.context import Context
@@ -62,16 +63,16 @@ def choose_branch(self, context: Context) -> str | Iterable[str]:
         dag_run: DagRun = context["dag_run"]  # type: ignore[assignment]
         if dag_run.run_type == DagRunType.MANUAL:
             self.log.info("Manually triggered DAG_Run: allowing execution to proceed.")
-            return list(context["task"].get_direct_relative_ids(upstream=False))
+            return list(self.get_direct_relative_ids(upstream=False))
 
-        next_info = self._get_next_run_info(context, dag_run)
-        now = pendulum.now("UTC")
+        dates = self._get_compare_dates(dag_run)
 
-        if next_info is None:
+        if dates is None:
             self.log.info("Last scheduled execution: allowing execution to proceed.")
-            return list(context["task"].get_direct_relative_ids(upstream=False))
+            return list(self.get_direct_relative_ids(upstream=False))
 
-        left_window, right_window = next_info.data_interval
+        now = pendulum.now("UTC")
+        left_window, right_window = dates
         self.log.info(
             "Checking latest only with left_window: %s right_window: %s now: %s",
             left_window,
@@ -79,37 +80,48 @@
             now,
         )
 
-        if left_window == right_window:
-            self.log.info(
-                "Zero-length interval [%s, %s) from timetable (%s); treating current run as latest.",
-                left_window,
-                right_window,
-                self.dag.timetable.__class__,
-            )
-            return list(context["task"].get_direct_relative_ids(upstream=False))
-
         if not left_window < now <= right_window:
             self.log.info("Not latest execution, skipping downstream.")
             # we return an empty list, thus the parent BaseBranchOperator
             # won't exclude any downstream tasks from skipping.
             return []
 
-        self.log.info("Latest, allowing execution to proceed.")
-        return list(context["task"].get_direct_relative_ids(upstream=False))
-
-    def _get_next_run_info(self, context: Context, dag_run: DagRun) -> DagRunInfo | None:
-        dag: DAG = context["dag"]  # type: ignore[assignment]
+        self.log.info("Latest, allowing execution to proceed.")
+        return list(self.get_direct_relative_ids(upstream=False))
+
+    def _get_compare_dates(self, dag_run: DagRun) -> tuple[DateTime, DateTime] | None:
+        dagrun_date: DateTime
         if AIRFLOW_V_3_0_PLUS:
-            from airflow.timetables.base import DataInterval, TimeRestriction
+            dagrun_date = dag_run.logical_date or dag_run.run_after
+        else:
+            dagrun_date = dag_run.logical_date
 
-            time_restriction = TimeRestriction(earliest=None, latest=None, catchup=True)
-            current_interval = DataInterval(start=dag_run.data_interval_start, end=dag_run.data_interval_end)
+        # breakpoint()
+        from airflow.timetables.base import DataInterval, TimeRestriction
 
-            next_info = dag.timetable.next_dagrun_info(
-                last_automated_data_interval=current_interval,
-                restriction=time_restriction,
-            )
+        current_interval = DataInterval(
+            start=dag_run.data_interval_start or dagrun_date,
+            end=dag_run.data_interval_end or dagrun_date,
+        )
+        time_restriction = TimeRestriction(
+            earliest=None, latest=current_interval.end - timedelta(microseconds=1), catchup=True
+        )
+        if prev_info := self.dag.timetable.next_dagrun_info(
+            last_automated_data_interval=current_interval,
+            restriction=time_restriction,
+        ):
+            left = prev_info.data_interval.end
         else:
-            next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
-        return next_info
+            left = current_interval.start
+
+        time_restriction = TimeRestriction(earliest=current_interval.end, latest=None, catchup=True)
+        next_info = self.dag.timetable.next_dagrun_info(
+            last_automated_data_interval=current_interval,
+            restriction=time_restriction,
+        )
+
+        if not next_info:
+            return None
+
+        return (left, next_info.data_interval.end)
diff --git a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
index b976f41fa9a2e..81f89a5fdd936 100644
--- a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
+++ b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import datetime
+import operator
 
 import pytest
 import time_machine
@@ -115,9 +116,7 @@ def test_skipping_non_latest(self, dag_maker):
             start_date=timezone.utcnow(),
             logical_date=timezone.datetime(2016, 1, 1, 12),
             state=State.RUNNING,
-            data_interval=DataInterval(
-                timezone.datetime(2016, 1, 1, 12), timezone.datetime(2016, 1, 1, 12) + INTERVAL
-            ),
+            data_interval=DataInterval(timezone.datetime(2016, 1, 1, 12), timezone.datetime(2016, 1, 1, 12)),
             **triggered_by_kwargs,
         )
 
@@ -126,7 +125,7 @@ def test_skipping_non_latest(self, dag_maker):
             start_date=timezone.utcnow(),
             logical_date=END_DATE,
             state=State.RUNNING,
-            data_interval=DataInterval(END_DATE, END_DATE + INTERVAL),
+            data_interval=DataInterval(END_DATE + INTERVAL, END_DATE + INTERVAL),
             **triggered_by_kwargs,
         )
 
@@ -145,6 +144,7 @@ def test_skipping_non_latest(self, dag_maker):
                 latest_ti0.run()
             assert exc_info.value.tasks == [("downstream", -1)]
 
+            # TODO: Set state is needed until #45549 is completed.
             latest_ti0.set_state(State.SUCCESS)
             dr0.get_task_instance(task_id="downstream").set_state(State.SKIPPED)
 
@@ -156,6 +156,7 @@ def test_skipping_non_latest(self, dag_maker):
                 latest_ti1.run()
             assert exc_info.value.tasks == [("downstream", -1)]
 
+            # TODO: Set state is needed until #45549 is completed.
             latest_ti1.set_state(State.SUCCESS)
             dr1.get_task_instance(task_id="downstream").set_state(State.SKIPPED)
 
@@ -165,77 +166,49 @@ def test_skipping_non_latest(self, dag_maker):
             latest_ti2.task = latest_task
             latest_ti2.run()
 
-            latest_ti2.set_state(State.SUCCESS)
-
-            # Verify the state of the other downstream tasks
-            downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
-            downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
-            downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)
-
-            downstream_instances = get_task_instances("downstream")
-            exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances}
-            assert exec_date_to_downstream_state == {
-                timezone.datetime(2016, 1, 1): "skipped",
-                timezone.datetime(2016, 1, 1, 12): "skipped",
-                timezone.datetime(2016, 1, 2): "success",
-            }
-
-            downstream_instances = get_task_instances("downstream_2")
-            exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances}
-            assert exec_date_to_downstream_state == {
-                timezone.datetime(2016, 1, 1): None,
-                timezone.datetime(2016, 1, 1, 12): None,
-                timezone.datetime(2016, 1, 2): "success",
-            }
-
-            downstream_instances = get_task_instances("downstream_3")
-            exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances}
-            assert exec_date_to_downstream_state == {
-                timezone.datetime(2016, 1, 1): "success",
-                timezone.datetime(2016, 1, 1, 12): "success",
-                timezone.datetime(2016, 1, 2): "success",
-            }
-
+            date_getter = operator.attrgetter("logical_date")
         else:
             latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
+            date_getter = operator.attrgetter("execution_date")
 
-            downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
-            downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
-            downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)
+        latest_instances = get_task_instances("latest")
+        exec_date_to_latest_state = {date_getter(ti): ti.state for ti in latest_instances}
+        assert exec_date_to_latest_state == {
+            timezone.datetime(2016, 1, 1): "success",
+            timezone.datetime(2016, 1, 1, 12): "success",
+            timezone.datetime(2016, 1, 2): "success",
+        }
 
-            latest_instances = get_task_instances("latest")
-            exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances}
-            assert exec_date_to_latest_state == {
-                timezone.datetime(2016, 1, 1): "success",
-                timezone.datetime(2016, 1, 1, 12): "success",
-                timezone.datetime(2016, 1, 2): "success",
-            }
+        # Verify the state of the other downstream tasks
+        downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
+        downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
+        downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)
 
-            downstream_instances = get_task_instances("downstream")
-            exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
-            assert exec_date_to_downstream_state == {
-                timezone.datetime(2016, 1, 1): "skipped",
-                timezone.datetime(2016, 1, 1, 12): "skipped",
-                timezone.datetime(2016, 1, 2): "success",
-            }
+        downstream_instances = get_task_instances("downstream")
+        exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances}
+        assert exec_date_to_downstream_state == {
+            timezone.datetime(2016, 1, 1): "skipped",
+            timezone.datetime(2016, 1, 1, 12): "skipped",
+            timezone.datetime(2016, 1, 2): "success",
+        }
 
-            downstream_instances = get_task_instances("downstream_2")
-            exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
-            assert exec_date_to_downstream_state == {
-                timezone.datetime(2016, 1, 1): None,
-                timezone.datetime(2016, 1, 1, 12): None,
-                timezone.datetime(2016, 1, 2): "success",
-            }
+        downstream_instances = get_task_instances("downstream_2")
+        exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances}
+        assert exec_date_to_downstream_state == {
+            timezone.datetime(2016, 1, 1): None,
+            timezone.datetime(2016, 1, 1, 12): None,
+            timezone.datetime(2016, 1, 2): "success",
+        }
 
-            downstream_instances = get_task_instances("downstream_3")
-            exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
-            assert exec_date_to_downstream_state == {
-                timezone.datetime(2016, 1, 1): "success",
-                timezone.datetime(2016, 1, 1, 12): "success",
-                timezone.datetime(2016, 1, 2): "success",
-            }
+        downstream_instances = get_task_instances("downstream_3")
+        exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances}
+        assert exec_date_to_downstream_state == {
+            timezone.datetime(2016, 1, 1): "success",
+            timezone.datetime(2016, 1, 1, 12): "success",
+            timezone.datetime(2016, 1, 2): "success",
+        }
 
-    def test_not_skipping_external(self, dag_maker):
+    def test_not_skipping_manual(self, dag_maker):
         with dag_maker(
             default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, schedule=INTERVAL, serialized=True
         ):

From da4875b858a7cb106f3feed9816752196a5fde4f Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Wed, 23 Apr 2025 14:33:17 +0100
Subject: [PATCH 2/4] Apply suggestions from code review

Co-authored-by: Kaxil Naik
Co-authored-by: Jens Scheffler <95105677+jscheffl@users.noreply.github.com>
---
 .../src/airflow/providers/standard/operators/latest_only.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
index ddc5959c41c4b..82a7bd4cdb0cf 100644
--- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
@@ -31,7 +31,7 @@
 if TYPE_CHECKING:
     from airflow.models import DagRun
-    from airflow.utils.timezone import DateTime
+    from pendulum.datetime import DateTime
 
 try:
     from airflow.sdk.definitions.context import Context
@@ -96,7 +96,6 @@ def _get_compare_dates(self, dag_run: DagRun) -> tuple[DateTime, DateTime] | Non
         else:
             dagrun_date = dag_run.logical_date
 
-        # breakpoint()
         from airflow.timetables.base import DataInterval, TimeRestriction
 
         current_interval = DataInterval(

From 593234210734609c978a1de760a87d6bdd9391d7 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Wed, 23 Apr 2025 16:50:01 +0100
Subject: [PATCH 3/4] Update providers/standard/src/airflow/providers/standard/operators/latest_only.py

---
 .../src/airflow/providers/standard/operators/latest_only.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
index 82a7bd4cdb0cf..397e935e318d8 100644
--- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
@@ -30,8 +30,9 @@
 from airflow.utils.types import DagRunType
 
 if TYPE_CHECKING:
-    from airflow.models import DagRun
     from pendulum.datetime import DateTime
+
+    from airflow.models import DagRun
 
 try:
     from airflow.sdk.definitions.context import Context

From 81af0335c9554b71332fe9e0eca246eacea43dd5 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Thu, 24 Apr 2025 09:45:45 +0100
Subject: [PATCH 4/4] fixup!

---
 .../src/airflow/providers/standard/operators/latest_only.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
index 397e935e318d8..251473dfcf110 100644
--- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py
@@ -31,7 +31,7 @@
 if TYPE_CHECKING:
     from pendulum.datetime import DateTime
-    
+
     from airflow.models import DagRun
 
 try:
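
For reference (not part of the patch series above), a rough standalone sketch of the window comparison the reworked operator performs. The helper name `is_latest_run` and the plain `datetime` arguments are illustrative assumptions rather than Airflow API; the operator itself derives the window from the timetable and the run's data interval or logical date.

from datetime import datetime, timezone


def is_latest_run(left_window: datetime, right_window: datetime, now: datetime | None = None) -> bool:
    # Mirrors the check in choose_branch(): the current run counts as "latest"
    # when `now` falls inside (left_window, right_window]. With the patch, a DAG
    # without a real data interval still gets a usable window, because the
    # interval bounds fall back to the run's logical date (or run_after).
    now = now or datetime.now(timezone.utc)
    return left_window < now <= right_window


left = datetime(2025, 4, 21, tzinfo=timezone.utc)
right = datetime(2025, 4, 22, tzinfo=timezone.utc)
print(is_latest_run(left, right, now=datetime(2025, 4, 21, 12, tzinfo=timezone.utc)))  # True: inside the window
print(is_latest_run(left, right, now=datetime(2025, 4, 23, 12, tzinfo=timezone.utc)))  # False: the window has passed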