diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 4e70b8781705b..622c5b2a11b95 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -73,7 +73,7 @@
     update,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import backref, joinedload, relationship
+from sqlalchemy.orm import backref, joinedload, load_only, relationship
 from sqlalchemy.sql import Select, expression
 
 import airflow.templates
@@ -3063,27 +3063,13 @@ def bulk_write_to_db(
                 session.add(orm_dag)
                 orm_dags.append(orm_dag)
 
-        dag_id_to_last_automated_run: dict[str, DagRun] = {}
+        latest_runs: dict[str, DagRun] = {}
         num_active_runs: dict[str, int] = {}
         # Skip these queries entirely if no DAGs can be scheduled to save time.
         if any(dag.timetable.can_be_scheduled for dag in dags):
             # Get the latest automated dag run for each existing dag as a single query (avoid n+1 query)
-            last_automated_runs_subq = (
-                select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
-                .where(
-                    DagRun.dag_id.in_(existing_dags),
-                    or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED),
-                )
-                .group_by(DagRun.dag_id)
-                .subquery()
-            )
-            last_automated_runs = session.scalars(
-                select(DagRun).where(
-                    DagRun.dag_id == last_automated_runs_subq.c.dag_id,
-                    DagRun.execution_date == last_automated_runs_subq.c.max_execution_date,
-                )
-            )
-            dag_id_to_last_automated_run = {run.dag_id: run for run in last_automated_runs}
+            query = cls._get_latest_runs_query(existing_dags, session)
+            latest_runs = {run.dag_id: run for run in session.scalars(query)}
 
             # Get number of active dagruns for all dags we are processing as a single query.
             num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session)
@@ -3117,7 +3103,7 @@ def bulk_write_to_db(
             orm_dag.timetable_description = dag.timetable.description
             orm_dag.processor_subdir = processor_subdir
 
-            last_automated_run: DagRun | None = dag_id_to_last_automated_run.get(dag.dag_id)
+            last_automated_run: DagRun | None = latest_runs.get(dag.dag_id)
             if last_automated_run is None:
                 last_automated_data_interval = None
             else:
@@ -3254,6 +3240,51 @@ def bulk_write_to_db(
         for dag in dags:
             cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)
 
+    @classmethod
+    def _get_latest_runs_query(cls, dags: Collection[str], session: Session) -> Select:
+        """
+        Query the database to retrieve the last automated run for each dag.
+
+        :param dags: dag_ids of the dags to query
+        :param session: SQLAlchemy session object
+        """
+        if len(dags) == 1:
+            # Index-optimized fast path that avoids the more complicated and slower GROUP BY query plan
+            existing_dag_id = list(dags)[0]
+            last_automated_runs_subq = (
+                select(func.max(DagRun.execution_date).label("max_execution_date"))
+                .where(
+                    DagRun.dag_id == existing_dag_id,
+                    DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
+                )
+                .scalar_subquery()
+            )
+            query = select(DagRun).where(
+                DagRun.dag_id == existing_dag_id, DagRun.execution_date == last_automated_runs_subq
+            )
+        else:
+            last_automated_runs_subq = (
+                select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
+                .where(
+                    DagRun.dag_id.in_(dags),
+                    DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
+                )
+                .group_by(DagRun.dag_id)
+                .subquery()
+            )
+            query = select(DagRun).where(
+                DagRun.dag_id == last_automated_runs_subq.c.dag_id,
+                DagRun.execution_date == last_automated_runs_subq.c.max_execution_date,
+            )
+        return query.options(
+            load_only(
+                DagRun.dag_id,
+                DagRun.execution_date,
+                DagRun.data_interval_start,
+                DagRun.data_interval_end,
+            )
+        )
+
     @provide_session
     def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION):
         """
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index f367b00abe64e..7c337ed965c28 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -952,6 +952,59 @@ def test_bulk_write_to_db(self):
         for row in session.query(DagModel.last_parsed_time).all():
             assert row[0] is not None
 
+    def test_bulk_write_to_db_single_dag(self):
+        """
+        Test bulk_write_to_db for a single dag, which takes the index-optimized fast path
+        """
+        clear_db_dags()
+        dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(1)]
+
+        with assert_queries_count(5):
+            DAG.bulk_write_to_db(dags)
+        with create_session() as session:
+            assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()}
+            assert {
+                ("dag-bulk-sync-0", "test-dag"),
+            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+
+            for row in session.query(DagModel.last_parsed_time).all():
+                assert row[0] is not None
+
+        # Re-syncing does more queries than the initial write, but the count is stable across re-syncs
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+
+    def test_bulk_write_to_db_multiple_dags(self):
+        """
+        Test bulk_write_to_db for multiple dags, which takes the GROUP BY path instead of the fast path
+        """
+        clear_db_dags()
+        dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(4)]
+
+        with assert_queries_count(5):
+            DAG.bulk_write_to_db(dags)
+        with create_session() as session:
+            assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
+                row[0] for row in session.query(DagModel.dag_id).all()
+            }
+            assert {
+                ("dag-bulk-sync-0", "test-dag"),
+                ("dag-bulk-sync-1", "test-dag"),
+                ("dag-bulk-sync-2", "test-dag"),
+                ("dag-bulk-sync-3", "test-dag"),
+            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+
+            for row in session.query(DagModel.last_parsed_time).all():
+                assert row[0] is not None
+
+        # Re-syncing does more queries than the initial write, but the count is stable across re-syncs
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+
     @pytest.mark.parametrize("interval", [None, "@daily"])
     def test_bulk_write_to_db_interval_save_runtime(self, interval):
         mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags)
@@ -4081,4 +4134,36 @@ def test_validate_setup_teardown_trigger_rule(self):
         with pytest.raises(
             Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
         ):
             dag.validate_setup_teardown()
+
+
+def test_get_latest_runs_query_one_dag(dag_maker, session):
+    with dag_maker(dag_id="dag1") as dag1:
+        ...
+    query = DAG._get_latest_runs_query(dags=[dag1.dag_id], session=session)
+    actual = [x.strip() for x in str(query.compile()).splitlines()]
+    expected = [
+        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
+        "FROM dag_run",
+        "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT max(dag_run.execution_date) AS max_execution_date",
+        "FROM dag_run",
+        "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
+    ]
+    assert actual == expected
+
+
+def test_get_latest_runs_query_two_dags(dag_maker, session):
+    with dag_maker(dag_id="dag1") as dag1:
+        ...
+    with dag_maker(dag_id="dag2") as dag2:
+        ...
+    query = DAG._get_latest_runs_query(dags=[dag1.dag_id, dag2.dag_id], session=session)
+    actual = [x.strip() for x in str(query.compile()).splitlines()]
+    expected = [
+        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
+        "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, max(dag_run.execution_date) AS max_execution_date",
+        "FROM dag_run",
+        "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
+        "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = anon_1.max_execution_date",
+    ]
+    assert actual == expected
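
Reviewer note: the snippet below is not part of the patch. It is a minimal, self-contained sketch of the two "latest row per group" query shapes that `_get_latest_runs_query` chooses between: a scalar MAX() subquery when a single dag_id is involved, which the planner can satisfy with one range scan of an index covering (dag_id, execution_date), and a GROUP BY subquery joined back to the base table otherwise. The `Run` model and `latest_runs_query` helper are invented names for illustration, and SQLAlchemy 1.4+ is assumed.

from __future__ import annotations

from sqlalchemy import Column, DateTime, Integer, String, create_engine, func, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Run(Base):
    __tablename__ = "run"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String, index=True)
    execution_date = Column(DateTime)


def latest_runs_query(dag_ids):
    if len(dag_ids) == 1:
        # Fast path: a scalar subquery computing MAX(execution_date) for one
        # key, compared directly against the outer row's execution_date.
        (dag_id,) = dag_ids
        max_date = (
            select(func.max(Run.execution_date))
            .where(Run.dag_id == dag_id)
            .scalar_subquery()
        )
        return select(Run).where(Run.dag_id == dag_id, Run.execution_date == max_date)
    # General path: compute MAX per key once via GROUP BY, then join the
    # derived table back to fetch the full rows.
    subq = (
        select(Run.dag_id, func.max(Run.execution_date).label("max_date"))
        .where(Run.dag_id.in_(dag_ids))
        .group_by(Run.dag_id)
        .subquery()
    )
    return select(Run).where(
        Run.dag_id == subq.c.dag_id, Run.execution_date == subq.c.max_date
    )


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    print(latest_runs_query(["a"]))       # scalar-subquery shape
    print(latest_runs_query(["a", "b"]))  # GROUP BY + join-back shape

Running the sketch prints the two compiled statements, which mirror the SQL asserted in the new `test_get_latest_runs_query_*` tests above.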