diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 2b009c1b09f23..83539d596540d 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -61,7 +61,9 @@ def register_dataset_change( extra=extra, ) ) - self._queue_dagruns(dataset_model, session) + if dataset_model.consuming_dags: + self._queue_dagruns(dataset_model, session) + session.flush() def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: # Possible race condition: if multiple dags or multiple (usually @@ -91,8 +93,6 @@ def _slow_path_queue_dagruns(self, dataset: DatasetModel, session: Session) -> N except exc.IntegrityError: self.log.debug("Skipping record %s", item, exc_info=True) - session.flush() - def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: from sqlalchemy.dialects.postgresql import insert @@ -101,7 +101,6 @@ def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> No stmt, [{'target_dag_id': target_dag.dag_id} for target_dag in dataset.consuming_dags], ) - session.flush() def resolve_dataset_manager() -> "DatasetManager": diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py index 4ff3b2884740b..42dffd76a11fd 100644 --- a/tests/datasets/test_manager.py +++ b/tests/datasets/test_manager.py @@ -80,3 +80,17 @@ def test_register_dataset_change(self, session, dag_maker, mock_task_instance): # Ensure we've created a dataset assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 assert session.query(DatasetDagRunQueue).count() == 2 + + def test_register_dataset_change_no_downstreams(self, session, mock_task_instance): + dsem = DatasetManager() + + ds = Dataset(uri="never_consumed") + dsm = DatasetModel(uri="never_consumed") + session.add(dsm) + session.flush() + + dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) + + # Ensure we've recorded a dataset event without queuing any dag runs + assert
session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 + assert session.query(DatasetDagRunQueue).count() == 0