diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 951602e14e648..ac6bf5812037e 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -1337,9 +1337,8 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) - .all() ) - active_runs_of_dags = Counter( - DagRun.active_runs_of_dags(dag_ids=(dm.dag_id for dm in dag_models), session=session), - ) + dag_ids = (dm.dag_id for dm in dag_models) + active_runs_of_dags = Counter(DagRun.active_runs_of_dags(dag_ids=dag_ids, session=session)) for dag_model in dag_models: dag = self.dagbag.get_dag(dag_model.dag_id, session=session) @@ -1512,11 +1511,11 @@ def _should_update_dag_next_dagruns( if not dag.timetable.can_be_scheduled: return False - # get active dag runs from DB if not available - if not total_active_runs: - total_active_runs = dag.get_num_active_runs(only_running=False, session=session) + if total_active_runs is None: + runs_dict = DagRun.active_runs_of_dags(dag_ids=[dag.dag_id], session=session) + total_active_runs = runs_dict.get(dag.dag_id, 0) - if total_active_runs and total_active_runs >= dag.max_active_runs: + if total_active_runs >= dag.max_active_runs: self.log.info( "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", dag_model.dag_id, diff --git a/airflow/models/dag.py b/airflow/models/dag.py index f8d9f55e56e18..f0a7d7f56be2c 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1283,28 +1283,6 @@ def get_active_runs(self): return active_dates - @provide_session - def get_num_active_runs(self, external_trigger=None, only_running=True, session=NEW_SESSION): - """ - Return the number of active "running" dag runs. - - :param external_trigger: True for externally triggered active dag runs - :param session: - :return: number greater than 0 for active dag runs - """ - query = select(func.count()).where(DagRun.dag_id == self.dag_id) - if only_running: - query = query.where(DagRun.state == DagRunState.RUNNING) - else: - query = query.where(DagRun.state.in_({DagRunState.RUNNING, DagRunState.QUEUED})) - - if external_trigger is not None: - query = query.where( - DagRun.external_trigger == (expression.true() if external_trigger else expression.false()) - ) - - return session.scalar(query) - @staticmethod @internal_api_call @provide_session diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index cad82e72b8b2d..4fd689d616b47 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -386,21 +386,20 @@ def refresh_from_db(self, session: Session = NEW_SESSION) -> None: @provide_session def active_runs_of_dags( cls, - dag_ids: Iterable[str] | None = None, - only_running: bool = False, + dag_ids: Iterable[str], session: Session = NEW_SESSION, ) -> dict[str, int]: - """Get the number of active dag runs for each dag.""" - query = select(cls.dag_id, func.count("*")) - if dag_ids is not None: - # 'set' called to avoid duplicate dag_ids, but converted back to 'list' - # because SQLAlchemy doesn't accept a set here. - query = query.where(cls.dag_id.in_(set(dag_ids))) - if only_running: - query = query.where(cls.state == DagRunState.RUNNING) - else: - query = query.where(cls.state.in_((DagRunState.RUNNING, DagRunState.QUEUED))) - query = query.group_by(cls.dag_id) + """ + Get the number of active dag runs for each dag. + + :meta private: + """ + query = ( + select(cls.dag_id, func.count("*")) + .where(cls.dag_id.in_(set(dag_ids))) + .where(cls.state.in_((DagRunState.RUNNING, DagRunState.QUEUED))) + .group_by(cls.dag_id) + ) return dict(iter(session.execute(query))) @classmethod diff --git a/newsfragments/43067.significant.rst b/newsfragments/43067.significant.rst new file mode 100644 index 0000000000000..d57fec6be9e72 --- /dev/null +++ b/newsfragments/43067.significant.rst @@ -0,0 +1,4 @@ +Remove DAG.get_num_active_runs + +We don't need this function. There's already an almost-identical function on DagRun that we can use, namely DagRun.active_runs_of_dags. +Also, make DagRun.active_runs_of_dags private. diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 01d1c5fe7a331..13ccfc74a0ffb 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -4477,14 +4477,14 @@ def complete_one_dagrun(): model: DagModel = session.get(DagModel, dag.dag_id) # Pre-condition - assert DagRun.active_runs_of_dags(session=session) == {"test_dag": 3} + assert DagRun.active_runs_of_dags(dag_ids=["test_dag"], session=session) == {"test_dag": 3} assert model.next_dagrun == timezone.DateTime(2016, 1, 3, tzinfo=UTC) assert model.next_dagrun_create_after is None complete_one_dagrun() - assert DagRun.active_runs_of_dags(session=session) == {"test_dag": 3} + assert DagRun.active_runs_of_dags(dag_ids=["test_dag"], session=session) == {"test_dag": 3} for _ in range(5): self.job_runner._do_scheduling(session)