diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py index fd5cf62bbad2f..251473dfcf110 100644 --- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py +++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py @@ -20,6 +20,7 @@ from __future__ import annotations from collections.abc import Iterable +from datetime import timedelta from typing import TYPE_CHECKING import pendulum @@ -29,8 +30,9 @@ from airflow.utils.types import DagRunType if TYPE_CHECKING: - from airflow.models import DAG, DagRun - from airflow.timetables.base import DagRunInfo + from pendulum.datetime import DateTime + + from airflow.models import DagRun try: from airflow.sdk.definitions.context import Context @@ -62,16 +64,16 @@ def choose_branch(self, context: Context) -> str | Iterable[str]: dag_run: DagRun = context["dag_run"] # type: ignore[assignment] if dag_run.run_type == DagRunType.MANUAL: self.log.info("Manually triggered DAG_Run: allowing execution to proceed.") - return list(context["task"].get_direct_relative_ids(upstream=False)) + return list(self.get_direct_relative_ids(upstream=False)) - next_info = self._get_next_run_info(context, dag_run) - now = pendulum.now("UTC") + dates = self._get_compare_dates(dag_run) - if next_info is None: + if dates is None: self.log.info("Last scheduled execution: allowing execution to proceed.") - return list(context["task"].get_direct_relative_ids(upstream=False)) + return list(self.get_direct_relative_ids(upstream=False)) - left_window, right_window = next_info.data_interval + now = pendulum.now("UTC") + left_window, right_window = dates self.log.info( "Checking latest only with left_window: %s right_window: %s now: %s", left_window, @@ -79,37 +81,47 @@ def choose_branch(self, context: Context) -> str | Iterable[str]: now, ) - if left_window == right_window: - self.log.info( - "Zero-length interval [%s, %s) from timetable (%s); treating current run as latest.", - left_window, - right_window, - self.dag.timetable.__class__, - ) - return list(context["task"].get_direct_relative_ids(upstream=False)) - if not left_window < now <= right_window: self.log.info("Not latest execution, skipping downstream.") # we return an empty list, thus the parent BaseBranchOperator # won't exclude any downstream tasks from skipping. return [] - self.log.info("Latest, allowing execution to proceed.") - return list(context["task"].get_direct_relative_ids(upstream=False)) - def _get_next_run_info(self, context: Context, dag_run: DagRun) -> DagRunInfo | None: - dag: DAG = context["dag"] # type: ignore[assignment] + self.log.info("Latest, allowing execution to proceed.") + return list(self.get_direct_relative_ids(upstream=False)) + def _get_compare_dates(self, dag_run: DagRun) -> tuple[DateTime, DateTime] | None: + dagrun_date: DateTime if AIRFLOW_V_3_0_PLUS: - from airflow.timetables.base import DataInterval, TimeRestriction + dagrun_date = dag_run.logical_date or dag_run.run_after + else: + dagrun_date = dag_run.logical_date - time_restriction = TimeRestriction(earliest=None, latest=None, catchup=True) - current_interval = DataInterval(start=dag_run.data_interval_start, end=dag_run.data_interval_end) + from airflow.timetables.base import DataInterval, TimeRestriction - next_info = dag.timetable.next_dagrun_info( - last_automated_data_interval=current_interval, - restriction=time_restriction, - ) + current_interval = DataInterval( + start=dag_run.data_interval_start or dagrun_date, + end=dag_run.data_interval_end or dagrun_date, + ) + time_restriction = TimeRestriction( + earliest=None, latest=current_interval.end - timedelta(microseconds=1), catchup=True + ) + if prev_info := self.dag.timetable.next_dagrun_info( + last_automated_data_interval=current_interval, + restriction=time_restriction, + ): + left = prev_info.data_interval.end else: - next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False) - return next_info + left = current_interval.start + + time_restriction = TimeRestriction(earliest=current_interval.end, latest=None, catchup=True) + next_info = self.dag.timetable.next_dagrun_info( + last_automated_data_interval=current_interval, + restriction=time_restriction, + ) + + if not next_info: + return None + + return (left, next_info.data_interval.end) diff --git a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py index b976f41fa9a2e..81f89a5fdd936 100644 --- a/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py +++ b/providers/standard/tests/unit/standard/operators/test_latest_only_operator.py @@ -18,6 +18,7 @@ from __future__ import annotations import datetime +import operator import pytest import time_machine @@ -115,9 +116,7 @@ def test_skipping_non_latest(self, dag_maker): start_date=timezone.utcnow(), logical_date=timezone.datetime(2016, 1, 1, 12), state=State.RUNNING, - data_interval=DataInterval( - timezone.datetime(2016, 1, 1, 12), timezone.datetime(2016, 1, 1, 12) + INTERVAL - ), + data_interval=DataInterval(timezone.datetime(2016, 1, 1, 12), timezone.datetime(2016, 1, 1, 12)), **triggered_by_kwargs, ) @@ -126,7 +125,7 @@ def test_skipping_non_latest(self, dag_maker): start_date=timezone.utcnow(), logical_date=END_DATE, state=State.RUNNING, - data_interval=DataInterval(END_DATE, END_DATE + INTERVAL), + data_interval=DataInterval(END_DATE + INTERVAL, END_DATE + INTERVAL), **triggered_by_kwargs, ) @@ -145,6 +144,7 @@ def test_skipping_non_latest(self, dag_maker): latest_ti0.run() assert exc_info.value.tasks == [("downstream", -1)] + # TODO: Set state is needed until #45549 is completed. latest_ti0.set_state(State.SUCCESS) dr0.get_task_instance(task_id="downstream").set_state(State.SKIPPED) @@ -156,6 +156,7 @@ def test_skipping_non_latest(self, dag_maker): latest_ti1.run() assert exc_info.value.tasks == [("downstream", -1)] + # TODO: Set state is needed until #45549 is completed. latest_ti1.set_state(State.SUCCESS) dr1.get_task_instance(task_id="downstream").set_state(State.SKIPPED) @@ -165,77 +166,49 @@ def test_skipping_non_latest(self, dag_maker): latest_ti2.task = latest_task latest_ti2.run() - latest_ti2.set_state(State.SUCCESS) - - # Verify the state of the other downstream tasks - downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE) - - downstream_instances = get_task_instances("downstream") - exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): "skipped", - timezone.datetime(2016, 1, 1, 12): "skipped", - timezone.datetime(2016, 1, 2): "success", - } - - downstream_instances = get_task_instances("downstream_2") - exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): None, - timezone.datetime(2016, 1, 1, 12): None, - timezone.datetime(2016, 1, 2): "success", - } - - downstream_instances = get_task_instances("downstream_3") - exec_date_to_downstream_state = {ti.logical_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): "success", - timezone.datetime(2016, 1, 1, 12): "success", - timezone.datetime(2016, 1, 2): "success", - } - + date_getter = operator.attrgetter("logical_date") else: latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) + date_getter = operator.attrgetter("execution_date") - downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE) + latest_instances = get_task_instances("latest") + exec_date_to_latest_state = {date_getter(ti): ti.state for ti in latest_instances} + assert exec_date_to_latest_state == { + timezone.datetime(2016, 1, 1): "success", + timezone.datetime(2016, 1, 1, 12): "success", + timezone.datetime(2016, 1, 2): "success", + } - latest_instances = get_task_instances("latest") - exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances} - assert exec_date_to_latest_state == { - timezone.datetime(2016, 1, 1): "success", - timezone.datetime(2016, 1, 1, 12): "success", - timezone.datetime(2016, 1, 2): "success", - } + # Verify the state of the other downstream tasks + downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) + downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) + downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE) - downstream_instances = get_task_instances("downstream") - exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): "skipped", - timezone.datetime(2016, 1, 1, 12): "skipped", - timezone.datetime(2016, 1, 2): "success", - } + downstream_instances = get_task_instances("downstream") + exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances} + assert exec_date_to_downstream_state == { + timezone.datetime(2016, 1, 1): "skipped", + timezone.datetime(2016, 1, 1, 12): "skipped", + timezone.datetime(2016, 1, 2): "success", + } - downstream_instances = get_task_instances("downstream_2") - exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): None, - timezone.datetime(2016, 1, 1, 12): None, - timezone.datetime(2016, 1, 2): "success", - } + downstream_instances = get_task_instances("downstream_2") + exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances} + assert exec_date_to_downstream_state == { + timezone.datetime(2016, 1, 1): None, + timezone.datetime(2016, 1, 1, 12): None, + timezone.datetime(2016, 1, 2): "success", + } - downstream_instances = get_task_instances("downstream_3") - exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { - timezone.datetime(2016, 1, 1): "success", - timezone.datetime(2016, 1, 1, 12): "success", - timezone.datetime(2016, 1, 2): "success", - } + downstream_instances = get_task_instances("downstream_3") + exec_date_to_downstream_state = {date_getter(ti): ti.state for ti in downstream_instances} + assert exec_date_to_downstream_state == { + timezone.datetime(2016, 1, 1): "success", + timezone.datetime(2016, 1, 1, 12): "success", + timezone.datetime(2016, 1, 2): "success", + } - def test_not_skipping_external(self, dag_maker): + def test_not_skipping_manual(self, dag_maker): with dag_maker( default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, schedule=INTERVAL, serialized=True ):