Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2177,16 +2177,14 @@ def get_task_assets(
if isinstance(obj, of_type):
yield task["task_id"], obj

def get_run_data_interval(self, run: DagRun) -> DataInterval | None:
    """Get the data interval of this run.

    :param run: the DagRun to read the interval from; must belong to this DAG.
    :return: the run's stored ``(data_interval_start, data_interval_end)`` pair,
        an interval inferred from the timetable for pre-AIP-39 runs that carry a
        ``logical_date``, or ``None`` when neither is available.
    :raises ValueError: if ``run`` references a different ``dag_id`` than this DAG.
    """
    if run.dag_id is not None and run.dag_id != self.dag_id:
        raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}")

    data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end")
    # Pre-AIP-39 runs have no stored interval columns; fall back to inferring
    # one from the timetable when a logical_date is present. If it is not,
    # None is the honest answer (matches the `| None` return annotation).
    if data_interval is None and run.logical_date is not None:
        data_interval = self._real_dag.timetable.infer_manual_data_interval(run_after=run.logical_date)

    return data_interval

Expand Down
76 changes: 75 additions & 1 deletion airflow-core/tests/unit/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,20 @@
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.triggers.file import FileDeleteTrigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey, AssetWatcher
from airflow.sdk import BaseOperator
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetUniqueKey,
AssetWatcher,
)
from airflow.sdk.definitions.decorators import task
from airflow.sdk.definitions.param import Param
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.serialized_objects import BaseSerialization, LazyDeserializedDAG, SerializedDAG
from airflow.timetables.base import DataInterval
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
from airflow.utils.db import LazySelectSequence
Expand Down Expand Up @@ -494,6 +502,47 @@ def map_me_but_slowly(a):
assert lazy_serialized_dag.has_task_concurrency_limits


@pytest.mark.db_test
@pytest.mark.parametrize(
    "create_dag_run_kwargs",
    (
        {},
        {
            "data_interval": None,
            "logical_date": pendulum.DateTime(2016, 1, 1, 0, 0, 0, tzinfo=Timezone("UTC")),
        },
        {"data_interval": None, "logical_date": None},
    ),
    ids=["post-AIP-39", "pre-AIP-39-should-infer", "pre-AIP-39"],
)
def test_serialized_dag_get_run_data_interval(create_dag_run_kwargs, dag_maker, session):
    """Test whether LazyDeserializedDAG can correctly get dag run data_interval

    post-AIP-39: the dag run itself contains both data_interval start and data_interval end, and thus can
    be retrieved directly
    pre-AIP-39-should-infer: the dag run itself has neither data_interval_start nor data_interval_end,
    and thus needs to infer the data_interval from its timetable
    pre-AIP-39: the dag run itself has neither data_interval_start nor data_interval_end, and its logical_date
    is none. it should return data_interval as none
    """
    with dag_maker(dag_id="test_dag", session=session, serialized=True) as dag:
        BaseOperator(task_id="test_task")
    session.commit()

    dr = dag_maker.create_dagrun(**create_dag_run_kwargs)
    ser_dict = SerializedDAG.to_dict(dag)
    deser_dag = LazyDeserializedDAG(data=ser_dict)

    # The call under test is the same in every case; only the expectation
    # differs, so compute the result once instead of duplicating the call
    # in both branches.
    data_interval = deser_dag.get_run_data_interval(dr)

    if "logical_date" in create_dag_run_kwargs and create_dag_run_kwargs["logical_date"] is None:
        # pre-AIP-39 run with no logical_date: nothing to infer from.
        assert data_interval is None
    else:
        assert data_interval == DataInterval(
            start=pendulum.DateTime(2015, 12, 31, 0, 0, 0, tzinfo=Timezone("UTC")),
            end=pendulum.DateTime(2016, 1, 1, 0, 0, 0, tzinfo=Timezone("UTC")),
        )


def test_get_task_assets():
asset1 = Asset("1")
with DAG("testdag") as source_dag:
Expand All @@ -510,3 +559,28 @@ def test_get_task_assets():
("c", asset1),
("d", asset1),
]


def test_lazy_dag_run_interval_wrong_dag():
    """get_run_data_interval must reject a run that belongs to a different DAG."""
    mismatched = LazyDeserializedDAG(data={"dag": {"dag_id": "dag1"}})
    with pytest.raises(ValueError, match="different DAGs"):
        mismatched.get_run_data_interval(DAG_RUN)


def test_lazy_dag_run_interval_missing_interval():
    """Deserializing without a serialization version in the payload must fail loudly."""
    bare_dag = LazyDeserializedDAG(data={"dag": {"dag_id": "test_dag_id"}})
    with pytest.raises(ValueError, match="Unsure how to deserialize version '<not present>'"):
        bare_dag.get_run_data_interval(DAG_RUN)


def test_lazy_dag_run_interval_success():
    """A run carrying both interval bounds yields them back as a DataInterval."""
    # NOTE(review): this mutates the shared module-level DAG_RUN fixture in
    # place, so the interval attributes leak into any later test reusing it —
    # consider a per-test run object; confirm against the other DAG_RUN tests.
    run = DAG_RUN
    run.data_interval_start = datetime(2025, 1, 1)
    run.data_interval_end = datetime(2025, 1, 2)

    lazy = LazyDeserializedDAG(data={"dag": {"dag_id": "test_dag_id"}})
    interval = lazy.get_run_data_interval(run)

    # Assert the actual round-trip, not just the type: the interval must be
    # built from the run's stored start/end columns.
    assert isinstance(interval, DataInterval)
    assert interval.start == datetime(2025, 1, 1)
    assert interval.end == datetime(2025, 1, 2)