diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index 979f815c5bbc0..82499ad3c4c1e 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -827,10 +827,7 @@ def __init__( self.tol = tol if isinstance(tol, float) else None self.has_tolerance = self.tol is not None - def execute(self, context: Context): - self.log.info("Executing SQL check: %s", self.sql) - records = self.get_db_hook().get_first(self.sql) - + def check_value(self, records): if not records: self._raise_exception(f"The following query returned zero rows: {self.sql}") @@ -862,6 +859,11 @@ def execute(self, context: Context): if not all(tests): self._raise_exception(error_msg) + def execute(self, context: Context): + self.log.info("Executing SQL check: %s", self.sql) + records = self.get_db_hook().get_first(self.sql) + self.check_value(records) + def _to_float(self, records): return [float(record) for record in records] diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 550c4096aa540..4a17c41ba2cc8 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -443,6 +443,10 @@ def execute(self, context: Context) -> None: # type: ignore[override] method_name="execute_complete", ) self._handle_job_error(job) + # job.result() returns a RowIterator. Mypy expects an instance of SupportsNext[Any] for + # the next() call which the RowIterator does not resemble to. Hence, ignore the arg-type error. + records = next(job.result()) # type: ignore[arg-type] + self.check_value(records) self.log.info("Current state of job %s is %s", job.job_id, job.state) @staticmethod diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index da61efa557c54..b0759068ee1aa 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -1919,11 +1919,11 @@ def test_bigquery_value_check_async(self, mock_hook, create_task_instance_of_ope exc.value.trigger, BigQueryValueCheckTrigger ), "Trigger is not a BigQueryValueCheckTrigger" - @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.execute") @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.defer") + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.check_value") @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_bigquery_value_check_operator_async_finish_before_deferred( - self, mock_hook, mock_defer, mock_execute, create_task_instance_of_operator + self, mock_hook, mock_check_value, mock_defer, create_task_instance_of_operator ): job_id = "123456" hash_ = "hash" @@ -1944,7 +1944,7 @@ def test_bigquery_value_check_operator_async_finish_before_deferred( ti.task.execute(MagicMock()) assert not mock_defer.called - assert mock_execute.called + assert mock_check_value.called @pytest.mark.parametrize( "kwargs, expected",