21 changes: 11 additions & 10 deletions airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
@@ -127,13 +127,14 @@ def execute(self, context: 'Context') -> None:
location=self.location,
impersonation_chain=self.impersonation_chain,
)
conn = hook.get_conn()
cursor = conn.cursor()
cursor.run_copy(
source_project_dataset_tables=self.source_project_dataset_tables,
destination_project_dataset_table=self.destination_project_dataset_table,
write_disposition=self.write_disposition,
create_disposition=self.create_disposition,
labels=self.labels,
encryption_configuration=self.encryption_configuration,
)

with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
hook.run_copy(
source_project_dataset_tables=self.source_project_dataset_tables,
destination_project_dataset_table=self.destination_project_dataset_table,
write_disposition=self.write_disposition,
create_disposition=self.create_disposition,
labels=self.labels,
encryption_configuration=self.encryption_configuration,
)
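
Note on the hunk above: the operator stops going through the deprecated DBAPI cursor (hook.get_conn().cursor().run_copy(...)) and calls run_copy on the BigQueryHook directly, suppressing the DeprecationWarning that the hook method itself still emits. The sketch below is illustrative only: the connection id, table ids and write disposition are assumptions, and the insert_job variant at the end is a possible follow-up rather than part of this change.

import warnings

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook(gcp_conn_id="google_cloud_default")  # assumed connection id

# What this PR does: call the hook method directly instead of the DBAPI cursor,
# and silence the DeprecationWarning that run_copy itself raises.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    hook.run_copy(
        source_project_dataset_tables="my-project.src_dataset.src_table",  # placeholder
        destination_project_dataset_table="my-project.dst_dataset.dst_table",  # placeholder
        write_disposition="WRITE_EMPTY",
    )

# Possible non-deprecated path (not part of this PR): express the same copy as a
# BigQuery job configuration and submit it with insert_job.
copy_job = {
    "copy": {
        "sourceTable": {"projectId": "my-project", "datasetId": "src_dataset", "tableId": "src_table"},
        "destinationTable": {"projectId": "my-project", "datasetId": "dst_dataset", "tableId": "dst_table"},
        "writeDisposition": "WRITE_EMPTY",
    }
}
hook.insert_job(configuration=copy_job, project_id="my-project")
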
133 changes: 72 additions & 61 deletions airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -18,6 +18,7 @@
"""This module contains a Google Cloud Storage to BigQuery operator."""

import json
import warnings
from typing import TYPE_CHECKING, Optional, Sequence, Union

from airflow.models import BaseOperator
@@ -163,8 +164,9 @@ def __init__(
allow_jagged_rows=False,
encoding="UTF-8",
max_id_key=None,
bigquery_conn_id='google_cloud_default',
google_cloud_storage_conn_id='google_cloud_default',
gcp_conn_id='google_cloud_default',
bigquery_conn_id=None,
google_cloud_storage_conn_id=None,
delegate_to=None,
schema_update_options=(),
src_fmt_configs=None,
@@ -179,6 +181,15 @@ def __init__(
description=None,
**kwargs,
):
# To preserve backward compatibility. Remove one day
if bigquery_conn_id or google_cloud_storage_conn_id:
warnings.warn(
"The bigquery_conn_id and google_cloud_storage_conn_id parameters have been deprecated. "
"You should pass only gcp_conn_id parameter. "
"Will be used bigquery_conn_id or google_cloud_storage_conn_id if gcp_conn_id not passed.",
DeprecationWarning,
stacklevel=2,
)

super().__init__(**kwargs)

@@ -209,8 +220,7 @@ def __init__(
self.encoding = encoding

self.max_id_key = max_id_key
self.bigquery_conn_id = bigquery_conn_id
self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
self.gcp_conn_id = gcp_conn_id or bigquery_conn_id or google_cloud_storage_conn_id
self.delegate_to = delegate_to

self.schema_update_options = schema_update_options
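
For DAG authors, the net effect of the two __init__ hunks above is that GCSToBigQueryOperator now takes a single gcp_conn_id, with the old bigquery_conn_id / google_cloud_storage_conn_id names kept only as deprecated fallbacks. A minimal usage sketch under that assumption; the bucket, objects and table below are placeholders:

from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

# Inside a DAG definition.
load_events = GCSToBigQueryOperator(
    task_id="gcs_to_bq_example",
    bucket="example-bucket",                      # placeholder bucket
    source_objects=["events/2022-01-01/*.json"],  # placeholder objects
    destination_project_dataset_table="my-project.analytics.events",  # placeholder table
    source_format="NEWLINE_DELIMITED_JSON",
    write_disposition="WRITE_TRUNCATE",
    gcp_conn_id="google_cloud_default",  # replaces both deprecated *_conn_id parameters
)
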
@@ -227,7 +237,7 @@ def __init__(

def execute(self, context: 'Context'):
bq_hook = BigQueryHook(
bigquery_conn_id=self.bigquery_conn_id,
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
location=self.location,
impersonation_chain=self.impersonation_chain,
@@ -236,7 +246,7 @@ def execute(self, context: 'Context'):
if not self.schema_fields:
if self.schema_object and self.source_format != 'DATASTORE_BACKUP':
gcs_hook = GCSHook(
gcp_conn_id=self.google_cloud_storage_conn_id,
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
@@ -247,72 +257,73 @@ def execute(self, context: 'Context'):
schema_fields = json.loads(blob.decode("utf-8"))
else:
schema_fields = None

else:
schema_fields = self.schema_fields

self.source_objects = (
self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
)
source_uris = [f'gs://{self.bucket}/{source_object}' for source_object in self.source_objects]
conn = bq_hook.get_conn()
cursor = conn.cursor()

if self.external_table:
cursor.create_external_table(
external_project_dataset_table=self.destination_project_dataset_table,
schema_fields=schema_fields,
source_uris=source_uris,
source_format=self.source_format,
compression=self.compression,
skip_leading_rows=self.skip_leading_rows,
field_delimiter=self.field_delimiter,
max_bad_records=self.max_bad_records,
quote_character=self.quote_character,
ignore_unknown_values=self.ignore_unknown_values,
allow_quoted_newlines=self.allow_quoted_newlines,
allow_jagged_rows=self.allow_jagged_rows,
encoding=self.encoding,
src_fmt_configs=self.src_fmt_configs,
encryption_configuration=self.encryption_configuration,
labels=self.labels,
description=self.description,
)
else:
cursor.run_load(
destination_project_dataset_table=self.destination_project_dataset_table,
schema_fields=schema_fields,
source_uris=source_uris,
source_format=self.source_format,
autodetect=self.autodetect,
create_disposition=self.create_disposition,
skip_leading_rows=self.skip_leading_rows,
write_disposition=self.write_disposition,
field_delimiter=self.field_delimiter,
max_bad_records=self.max_bad_records,
quote_character=self.quote_character,
ignore_unknown_values=self.ignore_unknown_values,
allow_quoted_newlines=self.allow_quoted_newlines,
allow_jagged_rows=self.allow_jagged_rows,
encoding=self.encoding,
schema_update_options=self.schema_update_options,
src_fmt_configs=self.src_fmt_configs,
time_partitioning=self.time_partitioning,
cluster_fields=self.cluster_fields,
encryption_configuration=self.encryption_configuration,
labels=self.labels,
description=self.description,
)

if cursor.use_legacy_sql:
escaped_table_name = f'[{self.destination_project_dataset_table}]'
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
bq_hook.create_external_table(
external_project_dataset_table=self.destination_project_dataset_table,
schema_fields=schema_fields,
source_uris=source_uris,
source_format=self.source_format,
compression=self.compression,
skip_leading_rows=self.skip_leading_rows,
field_delimiter=self.field_delimiter,
max_bad_records=self.max_bad_records,
quote_character=self.quote_character,
ignore_unknown_values=self.ignore_unknown_values,
allow_quoted_newlines=self.allow_quoted_newlines,
allow_jagged_rows=self.allow_jagged_rows,
encoding=self.encoding,
src_fmt_configs=self.src_fmt_configs,
encryption_configuration=self.encryption_configuration,
labels=self.labels,
description=self.description,
)
else:
escaped_table_name = f'`{self.destination_project_dataset_table}`'
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
bq_hook.run_load(
destination_project_dataset_table=self.destination_project_dataset_table,
schema_fields=schema_fields,
source_uris=source_uris,
source_format=self.source_format,
autodetect=self.autodetect,
create_disposition=self.create_disposition,
skip_leading_rows=self.skip_leading_rows,
write_disposition=self.write_disposition,
field_delimiter=self.field_delimiter,
max_bad_records=self.max_bad_records,
quote_character=self.quote_character,
ignore_unknown_values=self.ignore_unknown_values,
allow_quoted_newlines=self.allow_quoted_newlines,
allow_jagged_rows=self.allow_jagged_rows,
encoding=self.encoding,
schema_update_options=self.schema_update_options,
src_fmt_configs=self.src_fmt_configs,
time_partitioning=self.time_partitioning,
cluster_fields=self.cluster_fields,
encryption_configuration=self.encryption_configuration,
labels=self.labels,
description=self.description,
)

if self.max_id_key:
select_command = f'SELECT MAX({self.max_id_key}) FROM {escaped_table_name}'
cursor.execute(select_command)
row = cursor.fetchone()
select_command = f'SELECT MAX({self.max_id_key}) FROM `{self.destination_project_dataset_table}`'
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
job_id = bq_hook.run_query(
sql=select_command,
use_legacy_sql=False,
)
row = list(bq_hook.get_job(job_id).result())
if row:
max_id = row[0] if row[0] else 0
self.log.info(
@@ -322,4 +333,4 @@ def execute(self, context: 'Context'):
max_id,
)
else:
raise RuntimeError(f"The f{select_command} returned no rows!")
raise RuntimeError(f"The {select_command} returned no rows!")
@@ -47,7 +47,7 @@ def test_execute(self, mock_hook):
)

operator.execute(None)
mock_hook.return_value.get_conn.return_value.cursor.return_value.run_copy.assert_called_once_with(
mock_hook.return_value.run_copy.assert_called_once_with(
source_project_dataset_tables=source_project_dataset_tables,
destination_project_dataset_table=destination_project_dataset_table,
write_disposition=write_disposition,