From f340b32d1eeeb8f3e4d10953bde2ebc51aadb4a1 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Fri, 1 Sep 2023 22:28:03 +0530 Subject: [PATCH] Fix BigQueryValueCheckOperator deferrable mode optimisation PR #31872 tried to optimise the deferrable mode in BigQueryValueCheckOperator. However, when deciding whether to defer, it only checked the job status; it did not actually verify the query result against the expected pass value, and so returned success prematurely. This PR adds the missing logic to the optimised path so that the result is checked against the pass value and tolerance. closes: #34010 --- airflow/providers/common/sql/operators/sql.py | 10 ++++++---- airflow/providers/google/cloud/operators/bigquery.py | 4 ++++ .../providers/google/cloud/operators/test_bigquery.py | 6 +++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index 979f815c5bbc0..82499ad3c4c1e 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -827,10 +827,7 @@ def __init__( self.tol = tol if isinstance(tol, float) else None self.has_tolerance = self.tol is not None - def execute(self, context: Context): - self.log.info("Executing SQL check: %s", self.sql) - records = self.get_db_hook().get_first(self.sql) - + def check_value(self, records): if not records: self._raise_exception(f"The following query returned zero rows: {self.sql}") @@ -862,6 +859,11 @@ def execute(self, context: Context): if not all(tests): self._raise_exception(error_msg) + def execute(self, context: Context): + self.log.info("Executing SQL check: %s", self.sql) + records = self.get_db_hook().get_first(self.sql) + self.check_value(records) + def _to_float(self, records): return [float(record) for record in records] diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 550c4096aa540..4a17c41ba2cc8 100644 --- 
a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -443,6 +443,10 @@ def execute(self, context: Context) -> None: # type: ignore[override] method_name="execute_complete", ) self._handle_job_error(job) + # job.result() returns a RowIterator. Mypy expects an instance of SupportsNext[Any] for + # the next() call which the RowIterator does not resemble to. Hence, ignore the arg-type error. + records = next(job.result()) # type: ignore[arg-type] + self.check_value(records) self.log.info("Current state of job %s is %s", job.job_id, job.state) @staticmethod diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index da61efa557c54..b0759068ee1aa 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -1919,11 +1919,11 @@ def test_bigquery_value_check_async(self, mock_hook, create_task_instance_of_ope exc.value.trigger, BigQueryValueCheckTrigger ), "Trigger is not a BigQueryValueCheckTrigger" - @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.execute") @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.defer") + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.check_value") @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_bigquery_value_check_operator_async_finish_before_deferred( - self, mock_hook, mock_defer, mock_execute, create_task_instance_of_operator + self, mock_hook, mock_check_value, mock_defer, create_task_instance_of_operator ): job_id = "123456" hash_ = "hash" @@ -1944,7 +1944,7 @@ def test_bigquery_value_check_operator_async_finish_before_deferred( ti.task.execute(MagicMock()) assert not mock_defer.called - assert mock_execute.called + assert mock_check_value.called @pytest.mark.parametrize( 
"kwargs, expected",