From 151ba7549350d0d2600c2d9929d19db85f0c98a1 Mon Sep 17 00:00:00 2001
From: Dmytro Kazanzhy
Date: Thu, 24 Feb 2022 00:29:41 +0200
Subject: [PATCH] Suppress hook warnings from the Bigquery transfers

---
 .../cloud/transfers/bigquery_to_bigquery.py   |  21 +--
 .../google/cloud/transfers/gcs_to_bigquery.py | 133 ++++++++++--------
 .../transfers/test_bigquery_to_bigquery.py    |   2 +-
 .../cloud/transfers/test_gcs_to_bigquery.py   | 113 ++++++---------
 4 files changed, 130 insertions(+), 139 deletions(-)

diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
index d439062cf81b6..527033ac433d5 100644
--- a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
+++ b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
@@ -127,13 +127,14 @@ def execute(self, context: 'Context') -> None:
             location=self.location,
             impersonation_chain=self.impersonation_chain,
         )
-        conn = hook.get_conn()
-        cursor = conn.cursor()
-        cursor.run_copy(
-            source_project_dataset_tables=self.source_project_dataset_tables,
-            destination_project_dataset_table=self.destination_project_dataset_table,
-            write_disposition=self.write_disposition,
-            create_disposition=self.create_disposition,
-            labels=self.labels,
-            encryption_configuration=self.encryption_configuration,
-        )
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", DeprecationWarning)
+            hook.run_copy(
+                source_project_dataset_tables=self.source_project_dataset_tables,
+                destination_project_dataset_table=self.destination_project_dataset_table,
+                write_disposition=self.write_disposition,
+                create_disposition=self.create_disposition,
+                labels=self.labels,
+                encryption_configuration=self.encryption_configuration,
+            )
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
index b5b91d5e43b91..a58f01f2fe980 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -18,6 +18,7 @@
 """This module contains a Google Cloud Storage to BigQuery operator."""
 import json
+import warnings
 from typing import TYPE_CHECKING, Optional, Sequence, Union
 
 from airflow.models import BaseOperator
@@ -163,8 +164,9 @@ def __init__(
         allow_jagged_rows=False,
         encoding="UTF-8",
         max_id_key=None,
-        bigquery_conn_id='google_cloud_default',
-        google_cloud_storage_conn_id='google_cloud_default',
+        gcp_conn_id='google_cloud_default',
+        bigquery_conn_id=None,
+        google_cloud_storage_conn_id=None,
         delegate_to=None,
         schema_update_options=(),
         src_fmt_configs=None,
@@ -179,6 +181,15 @@ def __init__(
         description=None,
         **kwargs,
     ):
+        # To preserve backward compatibility. Remove one day
+        if bigquery_conn_id or google_cloud_storage_conn_id:
+            warnings.warn(
+                "The bigquery_conn_id and google_cloud_storage_conn_id parameters have been deprecated. "
+                "You should pass only the gcp_conn_id parameter. "
" + "Will be used bigquery_conn_id or google_cloud_storage_conn_id if gcp_conn_id not passed.", + DeprecationWarning, + stacklevel=2, + ) super().__init__(**kwargs) @@ -209,8 +220,7 @@ def __init__( self.encoding = encoding self.max_id_key = max_id_key - self.bigquery_conn_id = bigquery_conn_id - self.google_cloud_storage_conn_id = google_cloud_storage_conn_id + self.gcp_conn_id = gcp_conn_id or bigquery_conn_id or google_cloud_storage_conn_id self.delegate_to = delegate_to self.schema_update_options = schema_update_options @@ -227,7 +237,7 @@ def __init__( def execute(self, context: 'Context'): bq_hook = BigQueryHook( - bigquery_conn_id=self.bigquery_conn_id, + gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, location=self.location, impersonation_chain=self.impersonation_chain, @@ -236,7 +246,7 @@ def execute(self, context: 'Context'): if not self.schema_fields: if self.schema_object and self.source_format != 'DATASTORE_BACKUP': gcs_hook = GCSHook( - gcp_conn_id=self.google_cloud_storage_conn_id, + gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) @@ -247,7 +257,6 @@ def execute(self, context: 'Context'): schema_fields = json.loads(blob.decode("utf-8")) else: schema_fields = None - else: schema_fields = self.schema_fields @@ -255,64 +264,66 @@ def execute(self, context: 'Context'): self.source_objects if isinstance(self.source_objects, list) else [self.source_objects] ) source_uris = [f'gs://{self.bucket}/{source_object}' for source_object in self.source_objects] - conn = bq_hook.get_conn() - cursor = conn.cursor() if self.external_table: - cursor.create_external_table( - external_project_dataset_table=self.destination_project_dataset_table, - schema_fields=schema_fields, - source_uris=source_uris, - source_format=self.source_format, - compression=self.compression, - skip_leading_rows=self.skip_leading_rows, - field_delimiter=self.field_delimiter, - max_bad_records=self.max_bad_records, - quote_character=self.quote_character, - ignore_unknown_values=self.ignore_unknown_values, - allow_quoted_newlines=self.allow_quoted_newlines, - allow_jagged_rows=self.allow_jagged_rows, - encoding=self.encoding, - src_fmt_configs=self.src_fmt_configs, - encryption_configuration=self.encryption_configuration, - labels=self.labels, - description=self.description, - ) - else: - cursor.run_load( - destination_project_dataset_table=self.destination_project_dataset_table, - schema_fields=schema_fields, - source_uris=source_uris, - source_format=self.source_format, - autodetect=self.autodetect, - create_disposition=self.create_disposition, - skip_leading_rows=self.skip_leading_rows, - write_disposition=self.write_disposition, - field_delimiter=self.field_delimiter, - max_bad_records=self.max_bad_records, - quote_character=self.quote_character, - ignore_unknown_values=self.ignore_unknown_values, - allow_quoted_newlines=self.allow_quoted_newlines, - allow_jagged_rows=self.allow_jagged_rows, - encoding=self.encoding, - schema_update_options=self.schema_update_options, - src_fmt_configs=self.src_fmt_configs, - time_partitioning=self.time_partitioning, - cluster_fields=self.cluster_fields, - encryption_configuration=self.encryption_configuration, - labels=self.labels, - description=self.description, - ) - - if cursor.use_legacy_sql: - escaped_table_name = f'[{self.destination_project_dataset_table}]' + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + bq_hook.create_external_table( + 
+                    external_project_dataset_table=self.destination_project_dataset_table,
+                    schema_fields=schema_fields,
+                    source_uris=source_uris,
+                    source_format=self.source_format,
+                    compression=self.compression,
+                    skip_leading_rows=self.skip_leading_rows,
+                    field_delimiter=self.field_delimiter,
+                    max_bad_records=self.max_bad_records,
+                    quote_character=self.quote_character,
+                    ignore_unknown_values=self.ignore_unknown_values,
+                    allow_quoted_newlines=self.allow_quoted_newlines,
+                    allow_jagged_rows=self.allow_jagged_rows,
+                    encoding=self.encoding,
+                    src_fmt_configs=self.src_fmt_configs,
+                    encryption_configuration=self.encryption_configuration,
+                    labels=self.labels,
+                    description=self.description,
+                )
         else:
-            escaped_table_name = f'`{self.destination_project_dataset_table}`'
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", DeprecationWarning)
+                bq_hook.run_load(
+                    destination_project_dataset_table=self.destination_project_dataset_table,
+                    schema_fields=schema_fields,
+                    source_uris=source_uris,
+                    source_format=self.source_format,
+                    autodetect=self.autodetect,
+                    create_disposition=self.create_disposition,
+                    skip_leading_rows=self.skip_leading_rows,
+                    write_disposition=self.write_disposition,
+                    field_delimiter=self.field_delimiter,
+                    max_bad_records=self.max_bad_records,
+                    quote_character=self.quote_character,
+                    ignore_unknown_values=self.ignore_unknown_values,
+                    allow_quoted_newlines=self.allow_quoted_newlines,
+                    allow_jagged_rows=self.allow_jagged_rows,
+                    encoding=self.encoding,
+                    schema_update_options=self.schema_update_options,
+                    src_fmt_configs=self.src_fmt_configs,
+                    time_partitioning=self.time_partitioning,
+                    cluster_fields=self.cluster_fields,
+                    encryption_configuration=self.encryption_configuration,
+                    labels=self.labels,
+                    description=self.description,
+                )
 
         if self.max_id_key:
-            select_command = f'SELECT MAX({self.max_id_key}) FROM {escaped_table_name}'
-            cursor.execute(select_command)
-            row = cursor.fetchone()
+            select_command = f'SELECT MAX({self.max_id_key}) FROM `{self.destination_project_dataset_table}`'
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", DeprecationWarning)
+                job_id = bq_hook.run_query(
+                    sql=select_command,
+                    use_legacy_sql=False,
+                )
+                row = list(bq_hook.get_job(job_id).result())
 
             if row:
                 max_id = row[0] if row[0] else 0
                 self.log.info(
@@ -322,4 +333,4 @@ def execute(self, context: 'Context'):
                     max_id,
                 )
             else:
-                raise RuntimeError(f"The f{select_command} returned no rows!")
+                raise RuntimeError(f"The {select_command} returned no rows!")
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
index b3c0acd8ed510..a3995461ffb78 100644
--- a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
@@ -47,7 +47,7 @@ def test_execute(self, mock_hook):
         )
 
         operator.execute(None)
-        mock_hook.return_value.get_conn.return_value.cursor.return_value.run_copy.assert_called_once_with(
+        mock_hook.return_value.run_copy.assert_called_once_with(
             source_project_dataset_tables=source_project_dataset_tables,
             destination_project_dataset_table=destination_project_dataset_table,
             write_disposition=write_disposition,
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
index 7ef49ba32b6ed..43307975352f4 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -32,25 +32,6 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
-    @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
-    def test_execute_explicit_project_legacy(self, bq_hook):
-        operator = GCSToBigQueryOperator(
-            task_id=TASK_ID,
-            bucket=TEST_BUCKET,
-            source_objects=TEST_SOURCE_OBJECTS,
-            destination_project_dataset_table=TEST_EXPLICIT_DEST,
-            max_id_key=MAX_ID_KEY,
-        )
-
-        # using legacy SQL
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.use_legacy_sql = True
-
-        operator.execute(None)
-
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.execute.assert_called_once_with(
-            "SELECT MAX(id) FROM [test-project.dataset.table]"
-        )
-
     @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
     def test_execute_explicit_project(self, bq_hook):
         operator = GCSToBigQueryOperator(
@@ -61,13 +42,13 @@ def test_execute_explicit_project(self, bq_hook):
             max_id_key=MAX_ID_KEY,
         )
 
-        # using non-legacy SQL
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.use_legacy_sql = False
+        bq_hook.return_value.get_job.return_value.result.return_value = ('1',)
 
         operator.execute(None)
 
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.execute.assert_called_once_with(
-            "SELECT MAX(id) FROM `test-project.dataset.table`"
+        bq_hook.return_value.run_query.assert_called_once_with(
+            sql="SELECT MAX(id) FROM `test-project.dataset.table`",
+            use_legacy_sql=False,
         )
 
     @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
@@ -83,7 +64,7 @@ def test_labels(self, bq_hook):
 
         operator.execute(None)
 
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with(
+        bq_hook.return_value.run_load.assert_called_once_with(
             destination_project_dataset_table=mock.ANY,
             schema_fields=mock.ANY,
             source_uris=mock.ANY,
@@ -121,7 +102,7 @@ def test_description(self, bq_hook):
 
         operator.execute(None)
 
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with(
+        bq_hook.return_value.run_load.assert_called_once_with(
             destination_project_dataset_table=mock.ANY,
             schema_fields=mock.ANY,
             source_uris=mock.ANY,
@@ -160,26 +141,25 @@ def test_labels_external_table(self, bq_hook):
 
         operator.execute(None)
 
         # fmt: off
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.create_external_table. \
-            assert_called_once_with(
-                external_project_dataset_table=mock.ANY,
-                schema_fields=mock.ANY,
-                source_uris=mock.ANY,
-                source_format=mock.ANY,
-                compression=mock.ANY,
-                skip_leading_rows=mock.ANY,
-                field_delimiter=mock.ANY,
-                max_bad_records=mock.ANY,
-                quote_character=mock.ANY,
-                ignore_unknown_values=mock.ANY,
-                allow_quoted_newlines=mock.ANY,
-                allow_jagged_rows=mock.ANY,
-                encoding=mock.ANY,
-                src_fmt_configs=mock.ANY,
-                encryption_configuration=mock.ANY,
-                labels=LABELS,
-                description=mock.ANY,
-            )
+        bq_hook.return_value.create_external_table.assert_called_once_with(
+            external_project_dataset_table=mock.ANY,
+            schema_fields=mock.ANY,
+            source_uris=mock.ANY,
+            source_format=mock.ANY,
+            compression=mock.ANY,
+            skip_leading_rows=mock.ANY,
+            field_delimiter=mock.ANY,
+            max_bad_records=mock.ANY,
+            quote_character=mock.ANY,
+            ignore_unknown_values=mock.ANY,
+            allow_quoted_newlines=mock.ANY,
+            allow_jagged_rows=mock.ANY,
+            encoding=mock.ANY,
+            src_fmt_configs=mock.ANY,
+            encryption_configuration=mock.ANY,
+            labels=LABELS,
+            description=mock.ANY,
+        )
         # fmt: on
 
     @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
@@ -196,26 +176,25 @@ def test_description_external_table(self, bq_hook):
 
         operator.execute(None)
 
         # fmt: off
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.create_external_table. \
-            assert_called_once_with(
-                external_project_dataset_table=mock.ANY,
-                schema_fields=mock.ANY,
-                source_uris=mock.ANY,
-                source_format=mock.ANY,
-                compression=mock.ANY,
-                skip_leading_rows=mock.ANY,
-                field_delimiter=mock.ANY,
-                max_bad_records=mock.ANY,
-                quote_character=mock.ANY,
-                ignore_unknown_values=mock.ANY,
-                allow_quoted_newlines=mock.ANY,
-                allow_jagged_rows=mock.ANY,
-                encoding=mock.ANY,
-                src_fmt_configs=mock.ANY,
-                encryption_configuration=mock.ANY,
-                labels=mock.ANY,
-                description=DESCRIPTION,
-            )
+        bq_hook.return_value.create_external_table.assert_called_once_with(
+            external_project_dataset_table=mock.ANY,
+            schema_fields=mock.ANY,
+            source_uris=mock.ANY,
+            source_format=mock.ANY,
+            compression=mock.ANY,
+            skip_leading_rows=mock.ANY,
+            field_delimiter=mock.ANY,
+            max_bad_records=mock.ANY,
+            quote_character=mock.ANY,
+            ignore_unknown_values=mock.ANY,
+            allow_quoted_newlines=mock.ANY,
+            allow_jagged_rows=mock.ANY,
+            encoding=mock.ANY,
+            src_fmt_configs=mock.ANY,
+            encryption_configuration=mock.ANY,
+            labels=mock.ANY,
+            description=DESCRIPTION,
+        )
         # fmt: on
 
     @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
@@ -229,7 +208,7 @@ def test_source_objects_as_list(self, bq_hook):
 
         operator.execute(None)
 
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with(
+        bq_hook.return_value.run_load.assert_called_once_with(
             destination_project_dataset_table=mock.ANY,
             schema_fields=mock.ANY,
             source_uris=[f'gs://{TEST_BUCKET}/{source_object}' for source_object in TEST_SOURCE_OBJECTS],
@@ -265,7 +244,7 @@ def test_source_objects_as_string(self, bq_hook):
 
         operator.execute(None)
 
-        bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with(
+        bq_hook.return_value.run_load.assert_called_once_with(
             destination_project_dataset_table=mock.ANY,
             schema_fields=mock.ANY,
             source_uris=[f'gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}'],
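Note on the recurring pattern in this patch: the operators now call the hook-level methods (`run_copy`, `run_load`, `create_external_table`, `run_query`) directly instead of going through a BigQuery cursor, and they wrap those calls in `warnings.catch_warnings()` with `simplefilter("ignore", DeprecationWarning)` so the hooks' own deprecation warnings do not leak into task logs. The standalone sketch below illustrates only that suppression pattern; `legacy_copy` and `execute` are hypothetical stand-ins, not Airflow APIs, and the real operators wrap the actual hook calls the same way.

```python
import warnings


def legacy_copy(source: str, destination: str) -> str:
    """Hypothetical stand-in for a deprecated hook method such as BigQueryHook.run_copy."""
    warnings.warn(
        "legacy_copy is deprecated; use the job-based API instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return f"copied {source} -> {destination}"


def execute() -> str:
    # Same technique as in the operators above: suppress DeprecationWarning only,
    # and only for the duration of the deprecated call.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        result = legacy_copy("project.dataset.src", "project.dataset.dst")
    return result


if __name__ == "__main__":
    # Calling legacy_copy() outside the catch_warnings block would still emit
    # the warning; the suppression is scoped to the with-block.
    print(execute())
```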