From f0ba0ac2a40befd7d3f2d793fb01d0468e57721d Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 8 Sep 2022 19:17:22 -0700 Subject: [PATCH] Don't blow up when a task produces a dataset that is not consumed. If you had a dataset output on a task, and no DAG was recorded as consuming that dataset, it failed with a null value constraint violation in the db --- airflow/datasets/manager.py | 7 +++---- tests/datasets/test_manager.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 2b009c1b09f23..83539d596540d 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -61,7 +61,9 @@ def register_dataset_change( extra=extra, ) ) - self._queue_dagruns(dataset_model, session) + if dataset_model.consuming_dags: + self._queue_dagruns(dataset_model, session) + session.flush() def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: # Possible race condition: if multiple dags or multiple (usually @@ -91,8 +93,6 @@ def _slow_path_queue_dagruns(self, dataset: DatasetModel, session: Session) -> N except exc.IntegrityError: self.log.debug("Skipping record %s", item, exc_info=True) - session.flush() - def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: from sqlalchemy.dialects.postgresql import insert @@ -101,7 +101,6 @@ def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> No stmt, [{'target_dag_id': target_dag.dag_id} for target_dag in dataset.consuming_dags], ) - session.flush() def resolve_dataset_manager() -> "DatasetManager": diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py index 4ff3b2884740b..42dffd76a11fd 100644 --- a/tests/datasets/test_manager.py +++ b/tests/datasets/test_manager.py @@ -80,3 +80,17 @@ def test_register_dataset_change(self, session, dag_maker, mock_task_instance): # Ensure we've created a 
dataset assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 assert session.query(DatasetDagRunQueue).count() == 2 + + def test_register_dataset_change_no_downstreams(self, session, mock_task_instance): + dsem = DatasetManager() + + ds = Dataset(uri="never_consumed") + dsm = DatasetModel(uri="never_consumed") + session.add(dsm) + session.flush() + + dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) + + # Ensure we've created a dataset + assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 + assert session.query(DatasetDagRunQueue).count() == 0