diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 688faa648d589..198b12258ce91 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -3785,59 +3785,54 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
         """
         from airflow.models.serialized_dag import SerializedDagModel
 
+        NUM_DAGS_PER_DAGRUN_QUERY = cls.NUM_DAGS_PER_DAGRUN_QUERY
+        dataset_triggered_dag_info = {}
+
         def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses: dict) -> bool | None:
             # if dag was serialized before 2.9 and we *just* upgraded,
             # we may be dealing with old version. In that case,
             # just wait for the dag to be reserialized.
             try:
                 return cond.evaluate(statuses)
             except AttributeError:
                 log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
                 return None
 
-        # this loads all the DDRQ records.... may need to limit num dags
-        all_records = session.scalars(select(DatasetDagRunQueue)).all()
-        by_dag = defaultdict(list)
-        for r in all_records:
-            by_dag[r.target_dag_id].append(r)
-        del all_records
-        dag_statuses = {}
-        for dag_id, records in by_dag.items():
-            dag_statuses[dag_id] = {x.dataset.uri: True for x in records}
-        ser_dags = session.scalars(
-            select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
-        ).all()
-        for ser_dag in ser_dags:
-            dag_id = ser_dag.dag_id
-            statuses = dag_statuses[dag_id]
-            if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
-                del by_dag[dag_id]
-                del dag_statuses[dag_id]
-        del dag_statuses
-        dataset_triggered_dag_info = {}
-        for dag_id, records in by_dag.items():
-            times = sorted(x.created_at for x in records)
-            dataset_triggered_dag_info[dag_id] = (times[0], times[-1])
-        del by_dag
-        dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
-        if dataset_triggered_dag_ids:
-            exclusion_list = set(
-                session.scalars(
-                    select(DagModel.dag_id)
-                    .join(DagRun.dag_model)
-                    .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
-                    .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
-                    .group_by(DagModel.dag_id)
-                    .having(func.count() >= func.max(DagModel.max_active_runs))
-                )
-            )
-            if exclusion_list:
-                dataset_triggered_dag_ids -= exclusion_list
-                dataset_triggered_dag_info = {
-                    k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list
-                }
-
-        # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
+        dag_statuses: dict[str, dict[str, bool]] = {}
+
+        # Get the distinct target_dag_id values from DatasetDagRunQueue
+        distinct_dag_ids_subq = session.query(DatasetDagRunQueue.target_dag_id).distinct().subquery()
+
+        # Page through the dag ids in batches of NUM_DAGS_PER_DAGRUN_QUERY so that
+        # a single query never loads every DatasetDagRunQueue record into memory
+        batch_offset = 0
+        while True:
+            batch = (
+                session.query(distinct_dag_ids_subq.c.target_dag_id)
+                .order_by(distinct_dag_ids_subq.c.target_dag_id)
+                .limit(NUM_DAGS_PER_DAGRUN_QUERY)
+                .offset(batch_offset)
+                .all()
+            )
+
+            if not batch:
+                break  # No more batches to process
+
+            batch_dag_ids = [row[0] for row in batch]
+            batch_offset += NUM_DAGS_PER_DAGRUN_QUERY
+
+            for dag_id in batch_dag_ids:
+                # Load this dag's queue records once; they feed both the status
+                # map and the created_at times below
+                records = session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag_id).all()
+                dag_statuses[dag_id] = {record.dataset.uri: True for record in records}
+                ser_dag = session.query(SerializedDagModel).filter_by(dag_id=dag_id).first()
+                if ser_dag and dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=dag_statuses[dag_id]):
+                    # The dag is ready; record the earliest and latest queue times
+                    times = [record.created_at for record in records]
+                    if times:
+                        dataset_triggered_dag_info[dag_id] = (min(times), max(times))
+
         query = (
             select(cls)
             .where(
@@ -3846,7 +3841,7 @@ def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses: dict) -> bool
                 cls.has_import_errors == expression.false(),
                 or_(
                     cls.next_dagrun_create_after <= func.now(),
-                    cls.dag_id.in_(dataset_triggered_dag_ids),
+                    cls.dag_id.in_(dataset_triggered_dag_info.keys()),
                 ),
             )
             .order_by(cls.next_dagrun_create_after)
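
Note for reviewers: the new loop leans on LIMIT/OFFSET pagination over a
subquery of distinct dag ids, with the page size capped by
NUM_DAGS_PER_DAGRUN_QUERY. Below is a minimal, self-contained sketch of that
pattern under assumed names: the dag_queue table, DagQueueEntry model, and
iter_dag_id_batches helper are illustrative stand-ins, not Airflow's real
models.

    from __future__ import annotations

    from collections.abc import Iterator

    from sqlalchemy import Column, Integer, String, create_engine, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()
    BATCH_SIZE = 10  # stand-in for DagModel.NUM_DAGS_PER_DAGRUN_QUERY

    class DagQueueEntry(Base):
        __tablename__ = "dag_queue"  # hypothetical table, for illustration only
        id = Column(Integer, primary_key=True)
        target_dag_id = Column(String, index=True)

    def iter_dag_id_batches(session: Session, batch_size: int = BATCH_SIZE) -> Iterator[list[str]]:
        """Yield distinct target_dag_id values, batch_size at a time.

        Ordering by the column being paged is what keeps successive
        LIMIT/OFFSET windows stable and disjoint.
        """
        offset = 0
        while True:
            batch = session.scalars(
                select(DagQueueEntry.target_dag_id)
                .distinct()
                .order_by(DagQueueEntry.target_dag_id)
                .limit(batch_size)
                .offset(offset)
            ).all()
            if not batch:
                return  # no more pages
            yield list(batch)
            offset += batch_size

    if __name__ == "__main__":
        engine = create_engine("sqlite://")
        Base.metadata.create_all(engine)
        with Session(engine) as session:
            session.add_all(DagQueueEntry(target_dag_id=f"dag_{i % 7}") for i in range(25))
            session.commit()
            for dag_ids in iter_dag_id_batches(session, batch_size=3):
                print(dag_ids)  # at most 3 distinct dag ids per batch

One tradeoff worth noting: OFFSET pagination re-scans the rows it skips on
every page, so for very large queues a keyset variant (filtering on
target_dag_id greater than the last value seen, instead of offsetting) would
scale better. The offset form is the simpler one and is what the diff
implements.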