diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py index e3075ba9b8573..35689013c5563 100644 --- a/airflow/contrib/hooks/bigquery_hook.py +++ b/airflow/contrib/hooks/bigquery_hook.py @@ -22,20 +22,16 @@ implementation for BigQuery. """ -import time import six -from builtins import range -from copy import deepcopy -from six import iteritems - -from past.builtins import basestring from airflow import AirflowException from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook from airflow.hooks.dbapi_hook import DbApiHook from airflow.utils.log.logging_mixin import LoggingMixin -from googleapiclient.discovery import build -from googleapiclient.errors import HttpError + +from google.cloud import bigquery +from google.cloud.bigquery import dbapi +from google.api_core.exceptions import NotFound from pandas_gbq.gbq import \ _check_google_client_version as gbq_check_google_client_version from pandas_gbq import read_gbq @@ -54,33 +50,27 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin): def __init__(self, bigquery_conn_id='bigquery_default', delegate_to=None, - use_legacy_sql=True, + use_legacy_sql=False, location=None): super(BigQueryHook, self).__init__( gcp_conn_id=bigquery_conn_id, delegate_to=delegate_to) self.use_legacy_sql = use_legacy_sql self.location = location - def get_conn(self): - """ - Returns a BigQuery PEP 249 connection object. - """ - service = self.get_service() - project = self._get_field('project') - return BigQueryConnection( - service=service, - project_id=project, - use_legacy_sql=self.use_legacy_sql, + def get_client(self, project_id=None): + project_id = project_id if project_id is not None else self.project_id + return bigquery.Client( + project=project_id, + credentials=self._get_credentials(), location=self.location, ) - def get_service(self): + def get_conn(self, project_id=None): """ - Returns a BigQuery service object. + Returns a BigQuery PEP 249 connection object. """ - http_authorized = self._authorize() - return build( - 'bigquery', 'v2', http=http_authorized, cache_discovery=False) + project_id = project_id if project_id is not None else self.project_id + return dbapi.Connection(self.get_client(project_id)) def insert_rows(self, table, rows, target_fields=None, commit_every=1000): """ @@ -119,113 +109,37 @@ def get_pandas_df(self, sql, parameters=None, dialect=None): verbose=False, private_key=private_key) - def table_exists(self, project_id, dataset_id, table_id): + def table_exists(self, dataset_id, table_id, project_id=None): """ Checks for the existence of a table in Google BigQuery. - :param project_id: The Google cloud project in which to look for the - table. The connection supplied to the hook must provide access to - the specified project. - :type project_id: str :param dataset_id: The name of the dataset in which to look for the table. :type dataset_id: str :param table_id: The name of the table to check the existence of. :type table_id: str + :param project_id: The Google cloud project in which to look for the + table. The connection supplied to the hook must provide access to + the specified project. 
+ :type project_id: str """ - service = self.get_service() + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) try: - service.tables().get( - projectId=project_id, datasetId=dataset_id, - tableId=table_id).execute() + client.get_table( + bigquery.TableReference( + bigquery.DatasetReference(project_id, dataset_id), + table_id, + ) + ) return True - except HttpError as e: - if e.resp['status'] == '404': - return False - raise - - -class BigQueryPandasConnector(GbqConnector): - """ - This connector behaves identically to GbqConnector (from Pandas), except - that it allows the service to be injected, and disables a call to - self.get_credentials(). This allows Airflow to use BigQuery with Pandas - without forcing a three legged OAuth connection. Instead, we can inject - service account credentials into the binding. - """ - - def __init__(self, - project_id, - service, - reauth=False, - verbose=False, - dialect='legacy'): - super(BigQueryPandasConnector, self).__init__(project_id) - gbq_check_google_client_version() - gbq_test_google_api_imports() - self.project_id = project_id - self.reauth = reauth - self.service = service - self.verbose = verbose - self.dialect = dialect - - -class BigQueryConnection(object): - """ - BigQuery does not have a notion of a persistent connection. Thus, these - objects are small stateless factories for cursors, which do all the real - work. - """ - - def __init__(self, *args, **kwargs): - self._args = args - self._kwargs = kwargs - - def close(self): - """ BigQueryConnection does not have anything to close. """ - pass - - def commit(self): - """ BigQueryConnection does not support transactions. """ - pass - - def cursor(self): - """ Return a new :py:class:`Cursor` object using the connection. """ - return BigQueryCursor(*self._args, **self._kwargs) - - def rollback(self): - raise NotImplementedError( - "BigQueryConnection does not have transactions") - - -class BigQueryBaseCursor(LoggingMixin): - """ - The BigQuery base cursor contains helper methods to execute queries against - BigQuery. The methods can be used directly by operators, in cases where a - PEP 249 cursor isn't needed. - """ - - def __init__(self, - service, - project_id, - use_legacy_sql=True, - api_resource_configs=None, - location=None): - - self.service = service - self.project_id = project_id - self.use_legacy_sql = use_legacy_sql - if api_resource_configs: - _validate_value("api_resource_configs", api_resource_configs, dict) - self.api_resource_configs = api_resource_configs \ - if api_resource_configs else {} - self.running_job_id = None - self.location = location + except NotFound: + return False def create_empty_table(self, - project_id, dataset_id, table_id, + project_id=None, schema_fields=None, time_partitioning=None, labels=None, @@ -234,12 +148,12 @@ def create_empty_table(self, Creates a new, empty table in the dataset. To create a view, which is defined by a SQL query, parse a dictionary to 'view' kwarg - :param project_id: The project to create the table into. - :type project_id: str :param dataset_id: The dataset to create the table into. :type dataset_id: str :param table_id: The Name of the table to be created. :type table_id: str + :param project_id: The project to create the table into. 
+ :type project_id: str :param schema_fields: If set, the schema field list as defined here: https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema :type schema_fields: list @@ -271,43 +185,33 @@ def create_empty_table(self, :return: """ - project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) - table_resource = { - 'tableReference': { - 'tableId': table_id - } - } + table_ref = bigquery.TableReference( + bigquery.DatasetReference(project_id, dataset_id), + table_id, + ) + table = bigquery.Table(table_ref) if schema_fields: - table_resource['schema'] = {'fields': schema_fields} + table.fields = schema_fields if time_partitioning: - table_resource['timePartitioning'] = time_partitioning + table.time_partitioning = time_partitioning if labels: - table_resource['labels'] = labels + table.labels = labels if view: - table_resource['view'] = view + table.view_query = view self.log.info('Creating Table %s:%s.%s', project_id, dataset_id, table_id) - try: - self.service.tables().insert( - projectId=project_id, - datasetId=dataset_id, - body=table_resource).execute() - - self.log.info('Table created successfully: %s:%s.%s', - project_id, dataset_id, table_id) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content) - ) + client.create_table(table) + self.log.info('Table created successfully: %s:%s.%s', + project_id, dataset_id, table_id) def create_external_table(self, external_project_dataset_table, @@ -323,9 +227,8 @@ def create_external_table(self, quote_character=None, allow_quoted_newlines=False, allow_jagged_rows=False, - src_fmt_configs=None, - labels=None - ): + external_config_options=None, + labels=None): """ Creates a new external table in the dataset with the data in Google Cloud Storage. See here: @@ -383,18 +286,15 @@ def create_external_table(self, records, an invalid error is returned in the job result. Only applicable when soure_format is CSV. 
:type allow_jagged_rows: bool - :param src_fmt_configs: configure optional fields specific to the source format - :type src_fmt_configs: dict :param labels: a dictionary containing labels for the table, passed to BigQuery :type labels: dict """ + table_ref = _split_tablename(table_input=external_project_dataset_table, + default_project_id=self.project_id, + var_name='external_project_dataset_table') + table = bigquery.Table(table_ref) - if src_fmt_configs is None: - src_fmt_configs = {} - project_id, dataset_id, external_table_id = \ - _split_tablename(table_input=external_project_dataset_table, - default_project_id=self.project_id, - var_name='external_project_dataset_table') + client = self.get_client(table_ref.project) # bigquery only allows certain source formats # we check to make sure the passed source format is valid @@ -419,91 +319,46 @@ def create_external_table(self, "Please use one of the following types: {1}" .format(compression, allowed_compressions)) - table_resource = { - 'externalDataConfiguration': { - 'autodetect': autodetect, - 'sourceFormat': source_format, - 'sourceUris': source_uris, - 'compression': compression, - 'ignoreUnknownValues': ignore_unknown_values - }, - 'tableReference': { - 'projectId': project_id, - 'datasetId': dataset_id, - 'tableId': external_table_id, - } - } + external_config = bigquery.ExternalConfig(source_format) + external_config.autodetect = autodetect + external_config.compression = compression + external_config.source_uris = source_uris + external_config.ignore_unknown_values = ignore_unknown_values - if schema_fields: - table_resource['externalDataConfiguration'].update({ - 'schema': { - 'fields': schema_fields - } - }) + if external_config_options is not None: + if not isinstance(external_config_options, type(external_config.options)): + raise AirflowException( + 'external_config_options must have type {}'.format( + type(external_config.options))) + external_config.options = external_config_options - self.log.info('Creating external table: %s', external_project_dataset_table) + if schema_fields: + external_config.schema = schema_fields if max_bad_records: - table_resource['externalDataConfiguration']['maxBadRecords'] = max_bad_records + external_config.max_bad_records = max_bad_records - # if following fields are not specified in src_fmt_configs, - # honor the top-level params for backward-compatibility - if 'skipLeadingRows' not in src_fmt_configs: - src_fmt_configs['skipLeadingRows'] = skip_leading_rows - if 'fieldDelimiter' not in src_fmt_configs: - src_fmt_configs['fieldDelimiter'] = field_delimiter - if 'quote_character' not in src_fmt_configs: - src_fmt_configs['quote'] = quote_character - if 'allowQuotedNewlines' not in src_fmt_configs: - src_fmt_configs['allowQuotedNewlines'] = allow_quoted_newlines - if 'allowJaggedRows' not in src_fmt_configs: - src_fmt_configs['allowJaggedRows'] = allow_jagged_rows - - src_fmt_to_param_mapping = { - 'CSV': 'csvOptions', - 'GOOGLE_SHEETS': 'googleSheetsOptions' - } - - src_fmt_to_configs_mapping = { - 'csvOptions': [ - 'allowJaggedRows', 'allowQuotedNewlines', - 'fieldDelimiter', 'skipLeadingRows', - 'quote' - ], - 'googleSheetsOptions': ['skipLeadingRows'] - } - - if source_format in src_fmt_to_param_mapping.keys(): - - valid_configs = src_fmt_to_configs_mapping[ - src_fmt_to_param_mapping[source_format] - ] - - src_fmt_configs = { - k: v - for k, v in src_fmt_configs.items() if k in valid_configs - } + self.log.info('Creating external table: %s', external_project_dataset_table) - 
table_resource['externalDataConfiguration'][src_fmt_to_param_mapping[ - source_format]] = src_fmt_configs + # if following fields are not specified in external_config_options + # honor the top-level params for backward-compatibility + if external_config.skip_leading_rows is None: + external_config.options.skip_leading_rows = skip_leading_rows + if external_config.field_delimiter is None: + external_config.field_delimiter = field_delimiter + if external_config.options.quote_character is None: + external_config.options.quote_character = quote_character + if external_config.options.allow_quoted_newlines is None: + external_config.options.allow_quoted_newlines = allow_quoted_newlines + if external_config.options.allow_jagged_rows is None: + external_config.options.allow_jagged_rows = allow_jagged_rows if labels: - table_resource['labels'] = labels + table.labels = labels - try: - self.service.tables().insert( - projectId=project_id, - datasetId=dataset_id, - body=table_resource - ).execute() - - self.log.info('External table created successfully: %s', - external_project_dataset_table) - - except HttpError as err: - raise Exception( - 'BigQuery job failed. Error was: {}'.format(err.content) - ) + client.create_table(table) + self.log.info('External table created successfully: %s', + external_project_dataset_table) def patch_table(self, dataset_id, @@ -573,50 +428,55 @@ def patch_table(self, :type require_partition_filter: bool """ - project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) - table_resource = {} + table_ref = bigquery.TableReference( + bigquery.DatasetReference(project_id, dataset_id), + table_id, + ) + table = bigquery.Table(table_ref) + fields = [] if description is not None: - table_resource['description'] = description + table.description = description + fields.append('description') if expiration_time is not None: - table_resource['expirationTime'] = expiration_time + table.expires = expiration_time + fields.append('expires') if external_data_configuration: - table_resource['externalDataConfiguration'] = external_data_configuration + table.external_data_configuration = external_data_configuration + fields.append('external_data_configuration') if friendly_name is not None: - table_resource['friendlyName'] = friendly_name + table.friendly_name = friendly_name + fields.append('friendly_name') if labels: - table_resource['labels'] = labels + table.labels = labels + fields.append('labels') if schema: - table_resource['schema'] = {'fields': schema} + table.schema = schema + fields.append('schema') if time_partitioning: - table_resource['timePartitioning'] = time_partitioning + table.time_partitioning = time_partitioning + fields.append('time_partitioning') if view: - table_resource['view'] = view - if require_partition_filter is not None: - table_resource['requirePartitionFilter'] = require_partition_filter + table.view_query = view + fields.append('view_query') self.log.info('Patching Table %s:%s.%s', project_id, dataset_id, table_id) - try: - self.service.tables().patch( - projectId=project_id, - datasetId=dataset_id, - tableId=table_id, - body=table_resource).execute() - - self.log.info('Table patched successfully: %s:%s.%s', - project_id, dataset_id, table_id) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. 
Error was: {}'.format(err.content) - ) + client.update_table( + table, + fields=fields, + ) + + self.log.info('Table patched successfully: %s:%s.%s', + project_id, dataset_id, table_id) def run_query(self, sql, + project_id=None, destination_dataset_table=None, write_disposition='WRITE_EMPTY', allow_large_results=False, @@ -631,8 +491,8 @@ def run_query(self, schema_update_options=(), priority='INTERACTIVE', time_partitioning=None, - api_resource_configs=None, cluster_fields=None, + job_config=None, location=None): """ Executes a BigQuery SQL query. Optionally persists results in a BigQuery @@ -706,29 +566,10 @@ def run_query(self, https://cloud.google.com/bigquery/docs/locations#specifying_your_location :type location: str """ + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) - if time_partitioning is None: - time_partitioning = {} - - if location: - self.location = location - - if not api_resource_configs: - api_resource_configs = self.api_resource_configs - else: - _validate_value('api_resource_configs', - api_resource_configs, dict) - configuration = deepcopy(api_resource_configs) - if 'query' not in configuration: - configuration['query'] = {} - - else: - _validate_value("api_resource_configs['query']", - configuration['query'], dict) - - if sql is None and not configuration['query'].get('query', None): - raise TypeError('`BigQueryBaseCursor.run_query` ' - 'missing 1 required positional argument: `sql`') + job_config = job_config or bigquery.QueryJobConfig() # BigQuery also allows you to define how you want a table's schema to change # as a side effect of a query job @@ -754,87 +595,57 @@ def run_query(self, "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") if destination_dataset_table: - destination_project, destination_dataset, destination_table = \ - _split_tablename(table_input=destination_dataset_table, - default_project_id=self.project_id) - - destination_dataset_table = { - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table, - } - - if cluster_fields: - cluster_fields = {'fields': cluster_fields} + job_config.destination = _split_tablename(table_input=destination_dataset_table, + default_project_id=self.project_id) query_param_list = [ - (sql, 'query', None, six.string_types), (priority, 'priority', 'INTERACTIVE', six.string_types), - (use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool), - (query_params, 'queryParameters', None, dict), - (udf_config, 'userDefinedFunctionResources', None, list), - (maximum_billing_tier, 'maximumBillingTier', None, int), - (maximum_bytes_billed, 'maximumBytesBilled', None, float), - (time_partitioning, 'timePartitioning', {}, dict), - (schema_update_options, 'schemaUpdateOptions', None, tuple), - (destination_dataset_table, 'destinationTable', None, dict), - (cluster_fields, 'clustering', None, dict), + (use_legacy_sql, 'use_legacy_sql', self.use_legacy_sql, bool), + (query_params, 'query_parameters', None, list), + (udf_config, 'udf_resources', None, list), + (maximum_billing_tier, 'maximum_billing_tier', None, int), + (maximum_bytes_billed, 'maximum_bytes_billed', None, int), + (time_partitioning, 'time_partitioning', {}, bigquery.TimePartitioning), + (schema_update_options, 'schema_update_options', None, list), + (cluster_fields, 'clustering_fields', None, list), ] for param_tuple in query_param_list: param, param_name, param_default, param_type = param_tuple - if param_name not in configuration['query'] and param in [None, {}, ()]: - if 
param_name == 'timePartitioning': - param_default = _cleanse_time_partitioning( - destination_dataset_table, time_partitioning) + if param in [None, {}, ()]: param = param_default if param not in [None, {}, ()]: - _api_resource_configs_duplication_check( - param_name, param, configuration['query']) - - configuration['query'][param_name] = param + setattr(job_config, param_name, param) # check valid type of provided param, # it last step because we can get param from 2 sources, # and first of all need to find it - _validate_value(param_name, configuration['query'][param_name], - param_type) - - if param_name == 'schemaUpdateOptions' and param: - self.log.info("Adding experimental 'schemaUpdateOptions': " - "{0}".format(schema_update_options)) - - if param_name == 'destinationTable': - for key in ['projectId', 'datasetId', 'tableId']: - if key not in configuration['query']['destinationTable']: - raise ValueError( - "Not correct 'destinationTable' in " - "api_resource_configs. 'destinationTable' " - "must be a dict with {'projectId':'', " - "'datasetId':'', 'tableId':''}") - - configuration['query'].update({ - 'allowLargeResults': allow_large_results, - 'flattenResults': flatten_results, - 'writeDisposition': write_disposition, - 'createDisposition': create_disposition, - }) - - if 'useLegacySql' in configuration['query'] and \ - 'queryParameters' in configuration['query']: + # _validate_value(param_name, configuration['query'][param_name], + # param_type) + + if param_name == 'destination_table': + job_config.allow_large_results = allow_large_results + job_config.flatten_results = flatten_results + job_config.write_disposition = write_disposition + job_config.create_disposition = create_disposition + + if job_config.use_legacy_sql and job_config.query_parameters: raise ValueError("Query parameters are not allowed " "when using legacy SQL") if labels: - _api_resource_configs_duplication_check( - 'labels', labels, configuration) - configuration['labels'] = labels + job_config.labels = labels - return self.run_with_configuration(configuration) + return client.query( + sql, + project=project_id, + location=location, + job_config=job_config, + ).result() def run_extract( # noqa self, @@ -844,7 +655,8 @@ def run_extract( # noqa export_format='CSV', field_delimiter=',', print_header=True, - labels=None): + labels=None, + location=None): """ Executes a BigQuery extract command to copy data from BigQuery to Google Cloud Storage. See here: @@ -873,40 +685,39 @@ def run_extract( # noqa passed to BigQuery :type labels: dict """ + table_ref = _split_tablename(table_input=source_project_dataset_table, + default_project_id=self.project_id, + var_name='source_project_dataset_table') - source_project, source_dataset, source_table = \ - _split_tablename(table_input=source_project_dataset_table, - default_project_id=self.project_id, - var_name='source_project_dataset_table') - - configuration = { - 'extract': { - 'sourceTable': { - 'projectId': source_project, - 'datasetId': source_dataset, - 'tableId': source_table, - }, - 'compression': compression, - 'destinationUris': destination_cloud_storage_uris, - 'destinationFormat': export_format, - } - } + client = self.get_client(table_ref.project) + + job_config = bigquery.ExtractJobConfig( + compression=compression, + destination_format=export_format, + ) if labels: - configuration['labels'] = labels + job_config.labels = labels if export_format == 'CSV': # Only set fieldDelimiter and printHeader fields if using CSV. 
# Google does not like it if you set these fields for other export # formats. - configuration['extract']['fieldDelimiter'] = field_delimiter - configuration['extract']['printHeader'] = print_header + job_config.field_delimiter = field_delimiter + job_config.print_header = print_header - return self.run_with_configuration(configuration) + return client.extract_table( + table_ref, + destination_cloud_storage_uris, + location=location, + job_config=job_config, + ).result() def run_copy(self, source_project_dataset_tables, destination_project_dataset_table, + project_id=None, + location=None, write_disposition='WRITE_EMPTY', create_disposition='CREATE_IF_NEEDED', labels=None): @@ -936,50 +747,45 @@ def run_copy(self, passed to BigQuery :type labels: dict """ + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) + source_project_dataset_tables = ([ source_project_dataset_tables ] if not isinstance(source_project_dataset_tables, list) else source_project_dataset_tables) - source_project_dataset_tables_fixup = [] + source_table_refs = [] for source_project_dataset_table in source_project_dataset_tables: - source_project, source_dataset, source_table = \ - _split_tablename(table_input=source_project_dataset_table, - default_project_id=self.project_id, - var_name='source_project_dataset_table') - source_project_dataset_tables_fixup.append({ - 'projectId': - source_project, - 'datasetId': - source_dataset, - 'tableId': - source_table - }) - - destination_project, destination_dataset, destination_table = \ - _split_tablename(table_input=destination_project_dataset_table, - default_project_id=self.project_id) - configuration = { - 'copy': { - 'createDisposition': create_disposition, - 'writeDisposition': write_disposition, - 'sourceTables': source_project_dataset_tables_fixup, - 'destinationTable': { - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table - } - } - } + source_table_ref = _split_tablename(table_input=source_project_dataset_table, + default_project_id=project_id, + var_name='source_project_dataset_table') + source_table_refs.append(source_table_ref) + + dest_table_ref = _split_tablename(table_input=destination_project_dataset_table, + default_project_id=project_id) + + job_config = bigquery.CopyJobConfig( + write_disposition=write_disposition, + create_disposition=create_disposition, + ) if labels: - configuration['labels'] = labels + job_config.labels = labels - return self.run_with_configuration(configuration) + return client.copy_table( + source_table_refs, + dest_table_ref, + project=project_id, + location=location, + job_config=job_config, + ).result() def run_load(self, destination_project_dataset_table, source_uris, + project_id=None, + location=None, schema_fields=None, source_format='CSV', create_disposition='CREATE_IF_NEEDED', @@ -992,7 +798,7 @@ def run_load(self, allow_quoted_newlines=False, allow_jagged_rows=False, schema_update_options=(), - src_fmt_configs=None, + job_config=None, time_partitioning=None, cluster_fields=None, autodetect=False): @@ -1056,8 +862,6 @@ def run_load(self, :param schema_update_options: Allows the schema of the destination table to be updated as a side effect of the load job. :type schema_update_options: tuple - :param src_fmt_configs: configure optional fields specific to the source format - :type src_fmt_configs: dict :param time_partitioning: configure optional time partitioning fields i.e. 
partition by field, type and expiration as per API specifications. :type time_partitioning: dict @@ -1066,6 +870,8 @@ def run_load(self, time_partitioning. The order of columns given determines the sort order. :type cluster_fields: list of str """ + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) # bigquery only allows certain source formats # we check to make sure the passed source format is valid @@ -1077,9 +883,6 @@ def run_load(self, raise ValueError( 'You must either pass a schema or autodetect=True.') - if src_fmt_configs is None: - src_fmt_configs = {} - source_format = source_format.upper() allowed_formats = [ "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS", @@ -1104,41 +907,25 @@ def run_load(self, "Please only use one or more of the following options: {1}" .format(schema_update_options, allowed_schema_update_options)) - destination_project, destination_dataset, destination_table = \ - _split_tablename(table_input=destination_project_dataset_table, - default_project_id=self.project_id, - var_name='destination_project_dataset_table') - - configuration = { - 'load': { - 'autodetect': autodetect, - 'createDisposition': create_disposition, - 'destinationTable': { - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table, - }, - 'sourceFormat': source_format, - 'sourceUris': source_uris, - 'writeDisposition': write_disposition, - 'ignoreUnknownValues': ignore_unknown_values - } - } + table_ref = _split_tablename(table_input=destination_project_dataset_table, + default_project_id=project_id, + var_name='destination_project_dataset_table') + + job_config = job_config or bigquery.LoadJobConfig() + job_config.create_disposition = create_disposition + job_config.write_disposition = write_disposition + job_config.autodetect = autodetect + job_config.source_format = source_format + job_config.ignore_unknown_values = ignore_unknown_values - time_partitioning = _cleanse_time_partitioning( - destination_project_dataset_table, - time_partitioning - ) if time_partitioning: - configuration['load'].update({ - 'timePartitioning': time_partitioning - }) + job_config.time_partitioning = time_partitioning if cluster_fields: - configuration['load'].update({'clustering': {'fields': cluster_fields}}) + job_config.clustering_fields = cluster_fields if schema_fields: - configuration['load']['schema'] = {'fields': schema_fields} + job_config.schema = schema_fields if schema_update_options: if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: @@ -1146,178 +933,36 @@ def run_load(self, "allowed if write_disposition is " "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") else: - self.log.info( - "Adding experimental " - "'schemaUpdateOptions': {0}".format(schema_update_options)) - configuration['load'][ - 'schemaUpdateOptions'] = schema_update_options + job_config.schema_update_options = schema_update_options if max_bad_records: - configuration['load']['maxBadRecords'] = max_bad_records + job_config.max_bad_records = max_bad_records - # if following fields are not specified in src_fmt_configs, + # if following fields are not specified in job_config, # honor the top-level params for backward-compatibility - if 'skipLeadingRows' not in src_fmt_configs: - src_fmt_configs['skipLeadingRows'] = skip_leading_rows - if 'fieldDelimiter' not in src_fmt_configs: - src_fmt_configs['fieldDelimiter'] = field_delimiter - if 'ignoreUnknownValues' not in src_fmt_configs: - src_fmt_configs['ignoreUnknownValues'] = 
ignore_unknown_values + if job_config.skip_leading_rows is None: + job_config.skip_leading_rows = skip_leading_rows + if job_config.field_delimiter is None: + job_config.field_delimiter = field_delimiter + if job_config.ignore_unknown_values is None: + job_config.ignore_unknown_values = ignore_unknown_values if quote_character is not None: - src_fmt_configs['quote'] = quote_character + job_config.quote_character = quote_character if allow_quoted_newlines: - src_fmt_configs['allowQuotedNewlines'] = allow_quoted_newlines - - src_fmt_to_configs_mapping = { - 'CSV': [ - 'allowJaggedRows', 'allowQuotedNewlines', 'autodetect', - 'fieldDelimiter', 'skipLeadingRows', 'ignoreUnknownValues', - 'nullMarker', 'quote' - ], - 'DATASTORE_BACKUP': ['projectionFields'], - 'NEWLINE_DELIMITED_JSON': ['autodetect', 'ignoreUnknownValues'], - 'PARQUET': ['autodetect', 'ignoreUnknownValues'], - 'AVRO': [], - } - valid_configs = src_fmt_to_configs_mapping[source_format] - src_fmt_configs = { - k: v - for k, v in src_fmt_configs.items() if k in valid_configs - } - configuration['load'].update(src_fmt_configs) + job_config.allow_quoted_newlines = allow_quoted_newlines if allow_jagged_rows: - configuration['load']['allowJaggedRows'] = allow_jagged_rows - - return self.run_with_configuration(configuration) - - def run_with_configuration(self, configuration): - """ - Executes a BigQuery SQL query. See here: - - https://cloud.google.com/bigquery/docs/reference/v2/jobs - - For more details about the configuration parameter. - - :param configuration: The configuration parameter maps directly to - BigQuery's configuration field in the job object. See - https://cloud.google.com/bigquery/docs/reference/v2/jobs for - details. - """ - jobs = self.service.jobs() - job_data = {'configuration': configuration} - - # Send query and wait for reply. - query_reply = jobs \ - .insert(projectId=self.project_id, body=job_data) \ - .execute() - self.running_job_id = query_reply['jobReference']['jobId'] - - # Wait for query to finish. - keep_polling_job = True - while keep_polling_job: - try: - if self.location: - job = jobs.get( - projectId=self.project_id, - jobId=self.running_job_id, - location=self.location).execute() - else: - job = jobs.get( - projectId=self.project_id, - jobId=self.running_job_id).execute() - if job['status']['state'] == 'DONE': - keep_polling_job = False - # Check if job had errors. - if 'errorResult' in job['status']: - raise Exception( - 'BigQuery job failed. Final error was: {}. The job was: {}'. - format(job['status']['errorResult'], job)) - else: - self.log.info('Waiting for job to complete : %s, %s', - self.project_id, self.running_job_id) - time.sleep(5) - - except HttpError as err: - if err.resp.status in [500, 503]: - self.log.info( - '%s: Retryable error, waiting for job to complete: %s', - err.resp.status, self.running_job_id) - time.sleep(5) - else: - raise Exception( - 'BigQuery job status check failed. 
Final error was: %s', - err.resp.status) - - return self.running_job_id - - def poll_job_complete(self, job_id): - jobs = self.service.jobs() - try: - if self.location: - job = jobs.get(projectId=self.project_id, - jobId=job_id, - location=self.location).execute() - else: - job = jobs.get(projectId=self.project_id, - jobId=job_id).execute() - if job['status']['state'] == 'DONE': - return True - except HttpError as err: - if err.resp.status in [500, 503]: - self.log.info( - '%s: Retryable error while polling job with id %s', - err.resp.status, job_id) - else: - raise Exception( - 'BigQuery job status check failed. Final error was: %s', - err.resp.status) - return False + job_config.allow_jagged_rows = allow_jagged_rows - def cancel_query(self): - """ - Cancel all started queries that have not yet completed - """ - jobs = self.service.jobs() - if (self.running_job_id and - not self.poll_job_complete(self.running_job_id)): - self.log.info('Attempting to cancel job : %s, %s', self.project_id, - self.running_job_id) - if self.location: - jobs.cancel( - projectId=self.project_id, - jobId=self.running_job_id, - location=self.location).execute() - else: - jobs.cancel( - projectId=self.project_id, - jobId=self.running_job_id).execute() - else: - self.log.info('No running BigQuery jobs to cancel.') - return - - # Wait for all the calls to cancel to finish - max_polling_attempts = 12 - polling_attempts = 0 - - job_complete = False - while polling_attempts < max_polling_attempts and not job_complete: - polling_attempts = polling_attempts + 1 - job_complete = self.poll_job_complete(self.running_job_id) - if job_complete: - self.log.info('Job successfully canceled: %s, %s', - self.project_id, self.running_job_id) - elif polling_attempts == max_polling_attempts: - self.log.info( - "Stopping polling due to timeout. Job with id %s " - "has not completed cancel and may or may not finish.", - self.running_job_id) - else: - self.log.info('Waiting for canceled job with id %s to finish.', - self.running_job_id) - time.sleep(5) + return client.load_table_from_uri( + source_uris, + table_ref, + project=project_id, + location=location, + job_config=job_config, + ).result() - def get_schema(self, dataset_id, table_id): + def get_schema(self, dataset_id, table_id, project_id=None): """ Get the schema for a given datset.table. see https://cloud.google.com/bigquery/docs/reference/v2/tables#resource @@ -1326,14 +971,18 @@ def get_schema(self, dataset_id, table_id): :param table_id: the table ID of the requested table :return: a table schema """ - tables_resource = self.service.tables() \ - .get(projectId=self.project_id, datasetId=dataset_id, tableId=table_id) \ - .execute() - return tables_resource['schema'] + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) + + table_ref = bigquery.TableReference( + bigquery.DatasetReference(project_id, dataset_id), + table_id, + ) + return client.get_table(table_ref).schema - def get_tabledata(self, dataset_id, table_id, + def get_tabledata(self, dataset_id, table_id, project_id=None, max_results=None, selected_fields=None, page_token=None, - start_index=None): + start_index=None, page_size=None): """ Get the data of a given dataset.table and optionally with selected columns. see https://cloud.google.com/bigquery/docs/reference/v2/tabledata/list @@ -1345,23 +994,26 @@ def get_tabledata(self, dataset_id, table_id, unspecified, all fields are returned. 
:param page_token: page token, returned from a previous call, identifying the result set. - :param start_index: zero based index of the starting row to read. - :return: map containing the requested rows. + :param page_token: The maximum number of rows in each page of results + from this request """ - optional_params = {} - if max_results: - optional_params['maxResults'] = max_results - if selected_fields: - optional_params['selectedFields'] = selected_fields - if page_token: - optional_params['pageToken'] = page_token - if start_index: - optional_params['startIndex'] = start_index - return (self.service.tabledata().list( - projectId=self.project_id, - datasetId=dataset_id, - tableId=table_id, - **optional_params).execute()) + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) + + table_ref = bigquery.TableReference( + bigquery.DatasetReference(project_id, dataset_id), + table_id, + ) + table = bigquery.Table(table_ref) + + return client.list_rows( + table, + selected_fields=selected_fields, + max_results=max_results, + page_token=page_token, + start_index=start_index, + page_size=page_size, + ) def run_table_delete(self, deletion_dataset_table, ignore_if_missing=False): @@ -1379,25 +1031,21 @@ def run_table_delete(self, deletion_dataset_table, :type ignore_if_missing: bool :return: """ - deletion_project, deletion_dataset, deletion_table = \ - _split_tablename(table_input=deletion_dataset_table, - default_project_id=self.project_id) + table_ref = _split_tablename(table_input=deletion_dataset_table, + default_project_id=self.project_id) + + client = self.get_client(table_ref.project) try: - self.service.tables() \ - .delete(projectId=deletion_project, - datasetId=deletion_dataset, - tableId=deletion_table) \ - .execute() - self.log.info('Deleted table %s:%s.%s.', deletion_project, - deletion_dataset, deletion_table) - except HttpError: + client.delete_table(table_ref) + self.log.info('Deleted table %s', table_ref) + except NotFound: if not ignore_if_missing: - raise Exception('Table deletion failed. Table does not exist.') + raise else: self.log.info('Table does not exist. Skipping.') - def run_table_upsert(self, dataset_id, table_resource, project_id=None): + def run_table_upsert(self, table, fields, project_id=None): """ creates a new, empty table in the dataset; If the table already exists, update the existing table. @@ -1409,108 +1057,21 @@ def run_table_upsert(self, dataset_id, table_resource, project_id=None): :param table_resource: a table resource. see https://cloud.google.com/bigquery/docs/reference/v2/tables#resource :type table_resource: dict - :param project_id: the project to upsert the table into. If None, + :param project_id: the project to upsert the table into. If None, project will be self.project_id. :return: """ - # check to see if the table exists - table_id = table_resource['tableReference']['tableId'] project_id = project_id if project_id is not None else self.project_id - tables_list_resp = self.service.tables().list( - projectId=project_id, datasetId=dataset_id).execute() - while True: - for table in tables_list_resp.get('tables', []): - if table['tableReference']['tableId'] == table_id: - # found the table, do update - self.log.info('Table %s:%s.%s exists, updating.', - project_id, dataset_id, table_id) - return self.service.tables().update( - projectId=project_id, - datasetId=dataset_id, - tableId=table_id, - body=table_resource).execute() - # If there is a next page, we need to check the next page. 
- if 'nextPageToken' in tables_list_resp: - tables_list_resp = self.service.tables()\ - .list(projectId=project_id, - datasetId=dataset_id, - pageToken=tables_list_resp['nextPageToken'])\ - .execute() - # If there is no next page, then the table doesn't exist. - else: - # do insert - self.log.info('Table %s:%s.%s does not exist. creating.', - project_id, dataset_id, table_id) - return self.service.tables().insert( - projectId=project_id, - datasetId=dataset_id, - body=table_resource).execute() - - def run_grant_dataset_view_access(self, - source_dataset, - view_dataset, - view_table, - source_project=None, - view_project=None): - """ - Grant authorized view access of a dataset to a view table. - If this view has already been granted access to the dataset, do nothing. - This method is not atomic. Running it may clobber a simultaneous update. - - :param source_dataset: the source dataset - :type source_dataset: str - :param view_dataset: the dataset that the view is in - :type view_dataset: str - :param view_table: the table of the view - :type view_table: str - :param source_project: the project of the source dataset. If None, - self.project_id will be used. - :type source_project: str - :param view_project: the project that the view is in. If None, - self.project_id will be used. - :type view_project: str - :return: the datasets resource of the source dataset. - """ + client = self.get_client(project_id) - # Apply default values to projects - source_project = source_project if source_project else self.project_id - view_project = view_project if view_project else self.project_id - - # we don't want to clobber any existing accesses, so we have to get - # info on the dataset before we can add view access - source_dataset_resource = self.service.datasets().get( - projectId=source_project, datasetId=source_dataset).execute() - access = source_dataset_resource[ - 'access'] if 'access' in source_dataset_resource else [] - view_access = { - 'view': { - 'projectId': view_project, - 'datasetId': view_dataset, - 'tableId': view_table - } - } - # check to see if the view we want to add already exists. - if view_access not in access: - self.log.info( - 'Granting table %s:%s.%s authorized view access to %s:%s dataset.', - view_project, view_dataset, view_table, source_project, - source_dataset) - access.append(view_access) - return self.service.datasets().patch( - projectId=source_project, - datasetId=source_dataset, - body={ - 'access': access - }).execute() - else: - # if view is already in access, do nothing. - self.log.info( - 'Table %s:%s.%s already has authorized view access to %s:%s dataset.', - view_project, view_dataset, view_table, source_project, source_dataset) - return source_dataset_resource - - def create_empty_dataset(self, dataset_id="", project_id="", - dataset_reference=None): + # Update or create table + try: + client.update_table(table, fields) + except NotFound: + client.create_table(table) + + def create_empty_dataset(self, dataset_id=None, project_id=None, + dataset_ref=None): """ Create a new empty dataset: https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/insert @@ -1524,81 +1085,37 @@ def create_empty_dataset(self, dataset_id="", project_id="", :param dataset_reference: Dataset reference that could be provided with request body. 
More info: https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource - :type dataset_reference: dict + :type dataset_reference: DatasetReference """ + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) - if dataset_reference: - _validate_value('dataset_reference', dataset_reference, dict) - else: - dataset_reference = {} - - if "datasetReference" not in dataset_reference: - dataset_reference["datasetReference"] = {} - - if not dataset_reference["datasetReference"].get("datasetId") and not dataset_id: - raise ValueError( - "{} not provided datasetId. Impossible to create dataset") - - dataset_required_params = [(dataset_id, "datasetId", ""), - (project_id, "projectId", self.project_id)] - for param_tuple in dataset_required_params: - param, param_name, param_default = param_tuple - if param_name not in dataset_reference['datasetReference']: - if param_default and not param: - self.log.info("{} was not specified. Will be used default " - "value {}.".format(param_name, - param_default)) - param = param_default - dataset_reference['datasetReference'].update( - {param_name: param}) - elif param: - _api_resource_configs_duplication_check( - param_name, param, - dataset_reference['datasetReference'], 'dataset_reference') - - dataset_id = dataset_reference.get("datasetReference").get("datasetId") - dataset_project_id = dataset_reference.get("datasetReference").get( - "projectId") - - self.log.info('Creating Dataset: %s in project: %s ', dataset_id, - dataset_project_id) + if dataset_ref is None: + dataset_ref = bigquery.DatasetReference(project_id, dataset_id) + self.log.info('Creating Dataset: %s in project: %s ', dataset_ref) - try: - self.service.datasets().insert( - projectId=dataset_project_id, - body=dataset_reference).execute() - self.log.info('Dataset created successfully: In project %s ' - 'Dataset %s', dataset_project_id, dataset_id) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content) - ) + client.create_dataset(dataset_ref) + self.log.info('Dataset created successfully: %s', dataset_ref) - def delete_dataset(self, project_id, dataset_id): + def delete_dataset(self, dataset_id, project_id=None): """ Delete a dataset of Big query in your project. - :param project_id: The name of the project where we have the dataset . - :type project_id: str :param dataset_id: The dataset to be delete. :type dataset_id: str + :param project_id: The name of the project where we have the dataset . + :type project_id: str :return: """ project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) + + dataset_ref = bigquery.DatasetReference(project_id, dataset_id) self.log.info('Deleting from project: %s Dataset:%s', project_id, dataset_id) - try: - self.service.datasets().delete( - projectId=project_id, - datasetId=dataset_id).execute() - self.log.info('Dataset deleted successfully: In project %s ' - 'Dataset %s', project_id, dataset_id) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. 
Error was: {}'.format(err.content) - ) + client.delete_dataset(dataset_ref) + self.log.info('Dataset deleted successfully: In project %s ' + 'Dataset %s', project_id, dataset_id) def get_dataset(self, dataset_id, project_id=None): """ @@ -1615,24 +1132,16 @@ def get_dataset(self, dataset_id, project_id=None): For more information, see Dataset Resource content: https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets#resource """ + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) - if not dataset_id or not isinstance(dataset_id, str): - raise ValueError("dataset_id argument must be provided and has " - "a type 'str'. You provided: {}".format(dataset_id)) - - dataset_project_id = project_id if project_id else self.project_id - - try: - dataset_resource = self.service.datasets().get( - datasetId=dataset_id, projectId=dataset_project_id).execute() - self.log.info("Dataset Resource: {}".format(dataset_resource)) - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content)) + dataset_ref = bigquery.DatasetReference(project_id, dataset_id) - return dataset_resource + dataset = client.get_dataset(dataset_ref) + self.log.info("Dataset Resource: {}".format(dataset)) + return dataset - def get_datasets_list(self, project_id=None): + def list_datasets(self, project_id=None): """ Method returns full list of BigQuery datasets in the current project @@ -1667,21 +1176,15 @@ def get_datasets_list(self, project_id=None): } ] """ - dataset_project_id = project_id if project_id else self.project_id - - try: - datasets_list = self.service.datasets().list( - projectId=dataset_project_id).execute()['datasets'] - self.log.info("Datasets List: {}".format(datasets_list)) - - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content)) + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) - return datasets_list + datasets = list(client.list_datasets(project_id)) + self.log.info("Datasets List: {}".format(datasets)) + return datasets - def insert_all(self, project_id, dataset_id, table_id, - rows, ignore_unknown_values=False, + def insert_all(self, dataset_id, table_id, rows, + project_id=None, ignore_unknown_values=False, skip_invalid_rows=False, fail_on_error=False): """ Method to stream data into BigQuery one record at a time without needing @@ -1716,252 +1219,63 @@ def insert_all(self, project_id, dataset_id, table_id, even if any insertion errors occur. :type fail_on_error: bool """ + project_id = project_id if project_id is not None else self.project_id + client = self.get_client(project_id) - dataset_project_id = project_id if project_id else self.project_id - - body = { - "rows": rows, - "ignoreUnknownValues": ignore_unknown_values, - "kind": "bigquery#tableDataInsertAllRequest", - "skipInvalidRows": skip_invalid_rows, - } - - try: - self.log.info('Inserting {} row(s) into Table {}:{}.{}'.format( - len(rows), dataset_project_id, - dataset_id, table_id)) - - resp = self.service.tabledata().insertAll( - projectId=dataset_project_id, datasetId=dataset_id, - tableId=table_id, body=body - ).execute() - - if 'insertErrors' not in resp: - self.log.info('All row(s) inserted successfully: {}:{}.{}'.format( - dataset_project_id, dataset_id, table_id)) - else: - error_msg = '{} insert error(s) occurred: {}:{}.{}. 
Details: {}'.format( - len(resp['insertErrors']), - dataset_project_id, dataset_id, table_id, resp['insertErrors']) - if fail_on_error: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(error_msg) - ) - self.log.info(error_msg) - except HttpError as err: - raise AirflowException( - 'BigQuery job failed. Error was: {}'.format(err.content) - ) - - -class BigQueryCursor(BigQueryBaseCursor): - """ - A very basic BigQuery PEP 249 cursor implementation. The PyHive PEP 249 - implementation was used as a reference: - - https://github.com/dropbox/PyHive/blob/master/pyhive/presto.py - https://github.com/dropbox/PyHive/blob/master/pyhive/common.py - """ - - def __init__(self, service, project_id, use_legacy_sql=True, location=None): - super(BigQueryCursor, self).__init__( - service=service, - project_id=project_id, - use_legacy_sql=use_legacy_sql, - location=location, + table_ref = bigquery.TableReference( + bigquery.DatasetReference(project_id, dataset_id), + table_id, ) - self.buffersize = None - self.page_token = None - self.job_id = None - self.buffer = [] - self.all_pages_loaded = False - - @property - def description(self): - """ The schema description method is not currently implemented. """ - raise NotImplementedError - - def close(self): - """ By default, do nothing """ - pass - - @property - def rowcount(self): - """ By default, return -1 to indicate that this is not supported. """ - return -1 - - def execute(self, operation, parameters=None): - """ - Executes a BigQuery query, and returns the job ID. - - :param operation: The query to execute. - :type operation: str - :param parameters: Parameters to substitute into the query. - :type parameters: dict - """ - sql = _bind_parameters(operation, - parameters) if parameters else operation - self.job_id = self.run_query(sql) - - def executemany(self, operation, seq_of_parameters): - """ - Execute a BigQuery query multiple times with different parameters. - - :param operation: The query to execute. - :type operation: str - :param seq_of_parameters: List of dictionary parameters to substitute into the - query. - :type seq_of_parameters: list - """ - for parameters in seq_of_parameters: - self.execute(operation, parameters) - - def fetchone(self): - """ Fetch the next row of a query result set. """ - return self.next() - - def next(self): - """ - Helper method for fetchone, which returns the next row from a buffer. - If the buffer is empty, attempts to paginate through the result set for - the next page, and load it into the buffer. - """ - if not self.job_id: - return None - - if len(self.buffer) == 0: - if self.all_pages_loaded: - return None - - query_results = (self.service.jobs().getQueryResults( - projectId=self.project_id, - jobId=self.job_id, - pageToken=self.page_token).execute()) - - if 'rows' in query_results and query_results['rows']: - self.page_token = query_results.get('pageToken') - fields = query_results['schema']['fields'] - col_types = [field['type'] for field in fields] - rows = query_results['rows'] - - for dict_row in rows: - typed_row = ([ - _bq_cast(vs['v'], col_types[idx]) - for idx, vs in enumerate(dict_row['f']) - ]) - self.buffer.append(typed_row) - - if not self.page_token: - self.all_pages_loaded = True - else: - # Reset all state since we've exhausted the results. 
- self.page_token = None - self.job_id = None - self.page_token = None - return None - - return self.buffer.pop(0) - - def fetchmany(self, size=None): - """ - Fetch the next set of rows of a query result, returning a sequence of sequences - (e.g. a list of tuples). An empty sequence is returned when no more rows are - available. The number of rows to fetch per call is specified by the parameter. - If it is not given, the cursor's arraysize determines the number of rows to be - fetched. The method should try to fetch as many rows as indicated by the size - parameter. If this is not possible due to the specified number of rows not being - available, fewer rows may be returned. An :py:class:`~pyhive.exc.Error` - (or subclass) exception is raised if the previous call to - :py:meth:`execute` did not produce any result set or no call was issued yet. - """ - if size is None: - size = self.arraysize - result = [] - for _ in range(size): - one = self.fetchone() - if one is None: - break - else: - result.append(one) - return result - - def fetchall(self): - """ - Fetch all (remaining) rows of a query result, returning them as a sequence of - sequences (e.g. a list of tuples). - """ - result = [] - while True: - one = self.fetchone() - if one is None: - break - else: - result.append(one) - return result - - def get_arraysize(self): - """ Specifies the number of rows to fetch at a time with .fetchmany() """ - return self._buffersize if self.buffersize else 1 - - def set_arraysize(self, arraysize): - """ Specifies the number of rows to fetch at a time with .fetchmany() """ - self.buffersize = arraysize - - arraysize = property(get_arraysize, set_arraysize) - - def setinputsizes(self, sizes): - """ Does nothing by default """ - pass - - def setoutputsize(self, size, column=None): - """ Does nothing by default """ - pass + self.log.info('Inserting {} row(s) into Table {}:{}.{}'.format( + len(rows), project_id, + dataset_id, table_id)) + table = client.get_table(table_ref) + errors = client.insert_rows( + table, + rows, + skip_invalid_rows=skip_invalid_rows, + ignore_unknown_values=ignore_unknown_values, + ) -def _bind_parameters(operation, parameters): - """ Helper method that binds parameters to a SQL query. """ - # inspired by MySQL Python Connector (conversion.py) - string_parameters = {} - for (name, value) in iteritems(parameters): - if value is None: - string_parameters[name] = 'NULL' - elif isinstance(value, basestring): - string_parameters[name] = "'" + _escape(value) + "'" + if not errors: + self.log.info('All row(s) inserted successfully: {}:{}.{}'.format( + project_id, dataset_id, table_id)) else: - string_parameters[name] = str(value) - return operation % string_parameters + error_msg = '{} insert error(s) occurred: {}:{}.{}. Details: {}'.format( + len(errors), + project_id, dataset_id, table_id, errors) + if fail_on_error: + raise AirflowException( + 'BigQuery job failed. Error was: {}'.format(error_msg) + ) + self.log.info(error_msg) -def _escape(s): - """ Helper method that escapes parameters to a SQL query. """ - e = s - e = e.replace('\\', '\\\\') - e = e.replace('\n', '\\n') - e = e.replace('\r', '\\r') - e = e.replace("'", "\\'") - e = e.replace('"', '\\"') - return e - - -def _bq_cast(string_field, bq_type): +class BigQueryPandasConnector(GbqConnector): """ - Helper method that casts a BigQuery row to the appropriate data types. - This is useful because BigQuery returns all fields as strings. 
+ This connector behaves identically to GbqConnector (from Pandas), except + that it allows the service to be injected, and disables a call to + self.get_credentials(). This allows Airflow to use BigQuery with Pandas + without forcing a three legged OAuth connection. Instead, we can inject + service account credentials into the binding. """ - if string_field is None: - return None - elif bq_type == 'INTEGER': - return int(string_field) - elif bq_type == 'FLOAT' or bq_type == 'TIMESTAMP': - return float(string_field) - elif bq_type == 'BOOLEAN': - if string_field not in ['true', 'false']: - raise ValueError("{} must have value 'true' or 'false'".format( - string_field)) - return string_field == 'true' - else: - return string_field + + def __init__(self, + project_id, + service, + reauth=False, + verbose=False, + dialect='legacy'): + super(BigQueryPandasConnector, self).__init__(project_id) + gbq_check_google_client_version() + gbq_test_google_api_imports() + self.project_id = project_id + self.reauth = reauth + self.service = service + self.verbose = verbose + self.dialect = dialect def _split_tablename(table_input, default_project_id, var_name=None): @@ -1971,9 +1285,6 @@ def _split_tablename(table_input, default_project_id, var_name=None): 'Expected target table name in the format of ' '.. Got: {}'.format(table_input)) - if not default_project_id: - raise ValueError("INTERNAL: No default project is specified") - def var_print(var_name): if var_name is None: return "" @@ -2018,6 +1329,8 @@ def var_print(var_name): 'got {input}').format(var=var_print(var_name), input=table_input)) if project_id is None: + if not default_project_id: + raise ValueError("INTERNAL: No default project is specified") if var_name is not None: log = LoggingMixin().log log.info('Project not included in {var}: {input}; ' @@ -2027,20 +1340,10 @@ def var_print(var_name): project=default_project_id)) project_id = default_project_id - return project_id, dataset_id, table_id - - -def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in): - # if it is a partitioned table ($ is in the table name) add partition load option - - if time_partitioning_in is None: - time_partitioning_in = {} - - time_partitioning_out = {} - if destination_dataset_table and '$' in destination_dataset_table: - time_partitioning_out['type'] = 'DAY' - time_partitioning_out.update(time_partitioning_in) - return time_partitioning_out + return bigquery.TableReference( + bigquery.DatasetReference(project_id, dataset_id), + table_id, + ) def _validate_value(key, value, expected_type): @@ -2049,13 +1352,3 @@ def _validate_value(key, value, expected_type): if not isinstance(value, expected_type): raise TypeError("{} argument must have a type {} not {}".format( key, expected_type, type(value))) - - -def _api_resource_configs_duplication_check(key, value, config_dict, - config_dict_name='api_resource_configs'): - if key in config_dict and value != config_dict[key]: - raise ValueError("Values of {param_name} param are duplicated. " - "{dict_name} contained {param_name} param " - "in `query` config and {param_name} was also provided " - "with arg to run_query() method. Please remove duplicates." 
- .format(param_name=key, dict_name=config_dict_name)) diff --git a/airflow/contrib/operators/bigquery_check_operator.py b/airflow/contrib/operators/bigquery_check_operator.py index 247a1ae7fba1b..0d2deecb9d5dd 100644 --- a/airflow/contrib/operators/bigquery_check_operator.py +++ b/airflow/contrib/operators/bigquery_check_operator.py @@ -64,7 +64,7 @@ class BigQueryCheckOperator(CheckOperator): def __init__(self, sql, bigquery_conn_id='bigquery_default', - use_legacy_sql=True, + use_legacy_sql=False, *args, **kwargs): super(BigQueryCheckOperator, self).__init__(sql=sql, *args, **kwargs) self.bigquery_conn_id = bigquery_conn_id @@ -92,7 +92,7 @@ def __init__(self, sql, pass_value, tolerance=None, bigquery_conn_id='bigquery_default', - use_legacy_sql=True, + use_legacy_sql=False, *args, **kwargs): super(BigQueryValueCheckOperator, self).__init__( sql=sql, pass_value=pass_value, tolerance=tolerance, @@ -132,7 +132,7 @@ class BigQueryIntervalCheckOperator(IntervalCheckOperator): @apply_defaults def __init__(self, table, metrics_thresholds, date_filter_column='ds', days_back=-7, bigquery_conn_id='bigquery_default', - use_legacy_sql=True, *args, **kwargs): + use_legacy_sql=False, *args, **kwargs): super(BigQueryIntervalCheckOperator, self).__init__( table=table, metrics_thresholds=metrics_thresholds, date_filter_column=date_filter_column, days_back=days_back, diff --git a/airflow/contrib/operators/bigquery_get_data.py b/airflow/contrib/operators/bigquery_get_data.py index f5e6e50f066d5..9172f720bd50c 100644 --- a/airflow/contrib/operators/bigquery_get_data.py +++ b/airflow/contrib/operators/bigquery_get_data.py @@ -96,21 +96,10 @@ def execute(self, context): hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, delegate_to=self.delegate_to) - conn = hook.get_conn() - cursor = conn.cursor() - response = cursor.get_tabledata(dataset_id=self.dataset_id, - table_id=self.table_id, - max_results=self.max_results, - selected_fields=self.selected_fields) - - self.log.info('Total Extracted rows: %s', response['totalRows']) - rows = response['rows'] - - table_data = [] - for dict_row in rows: - single_row = [] - for fields in dict_row['f']: - single_row.append(fields['v']) - table_data.append(single_row) - - return table_data + rows = hook.get_tabledata(dataset_id=self.dataset_id, + table_id=self.table_id, + max_results=self.max_results, + selected_fields=self.selected_fields) + rows = list(rows) + self.log.info('Total Extracted rows: %s', len(rows)) + return rows diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py index f597db93a58a3..f4f1609f45bcd 100644 --- a/airflow/contrib/operators/bigquery_operator.py +++ b/airflow/contrib/operators/bigquery_operator.py @@ -117,7 +117,7 @@ def __init__(self, bigquery_conn_id='bigquery_default', delegate_to=None, udf_config=None, - use_legacy_sql=True, + use_legacy_sql=False, maximum_billing_tier=None, maximum_bytes_billed=None, create_disposition='CREATE_IF_NEEDED', @@ -147,7 +147,6 @@ def __init__(self, self.schema_update_options = schema_update_options self.query_params = query_params self.labels = labels - self.bq_cursor = None self.priority = priority self.time_partitioning = time_partitioning self.api_resource_configs = api_resource_configs @@ -155,17 +154,14 @@ def __init__(self, self.location = location def execute(self, context): - if self.bq_cursor is None: - self.log.info('Executing: %s', self.sql) - hook = BigQueryHook( - bigquery_conn_id=self.bigquery_conn_id, - 
use_legacy_sql=self.use_legacy_sql, - delegate_to=self.delegate_to, - location=self.location, - ) - conn = hook.get_conn() - self.bq_cursor = conn.cursor() - self.bq_cursor.run_query( + self.log.info('Executing: %s', self.sql) + hook = BigQueryHook( + bigquery_conn_id=self.bigquery_conn_id, + use_legacy_sql=self.use_legacy_sql, + delegate_to=self.delegate_to, + location=self.location, + ) + hook.run_query( sql=self.sql, destination_dataset_table=self.destination_dataset_table, write_disposition=self.write_disposition, @@ -184,12 +180,6 @@ def execute(self, context): cluster_fields=self.cluster_fields, ) - def on_kill(self): - super(BigQueryOperator, self).on_kill() - if self.bq_cursor is not None: - self.log.info('Cancelling running query') - self.bq_cursor.cancel_query() - class BigQueryCreateEmptyTableOperator(BaseOperator): """ @@ -328,10 +318,7 @@ def execute(self, context): else: schema_fields = self.schema_fields - conn = bq_hook.get_conn() - cursor = conn.cursor() - - cursor.create_empty_table( + bq_hook.create_empty_table( project_id=self.project_id, dataset_id=self.dataset_id, table_id=self.table_id, @@ -408,8 +395,8 @@ class BigQueryCreateExternalTableOperator(BaseOperator): work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: str - :param src_fmt_configs: configure optional fields specific to the source format - :type src_fmt_configs: dict + :param external_config_options: configure optional fields specific to the source format + :type external_config_options: dict :param labels: a dictionary containing labels for the table, passed to BigQuery :type labels: dict """ @@ -435,7 +422,7 @@ def __init__(self, bigquery_conn_id='bigquery_default', google_cloud_storage_conn_id='google_cloud_default', delegate_to=None, - src_fmt_configs={}, + external_config_options=None, labels=None, *args, **kwargs): @@ -462,7 +449,7 @@ def __init__(self, self.google_cloud_storage_conn_id = google_cloud_storage_conn_id self.delegate_to = delegate_to - self.src_fmt_configs = src_fmt_configs + self.external_config_options = external_config_options self.labels = labels def execute(self, context): @@ -482,10 +469,8 @@ def execute(self, context): source_uris = ['gs://{}/{}'.format(self.bucket, source_object) for source_object in self.source_objects] - conn = bq_hook.get_conn() - cursor = conn.cursor() - cursor.create_external_table( + bq_hook.create_external_table( external_project_dataset_table=self.destination_project_dataset_table, schema_fields=schema_fields, source_uris=source_uris, @@ -497,8 +482,8 @@ def execute(self, context): quote_character=self.quote_character, allow_quoted_newlines=self.allow_quoted_newlines, allow_jagged_rows=self.allow_jagged_rows, - src_fmt_configs=self.src_fmt_configs, - labels=self.labels + external_config_options=self.external_config_options, + labels=self.labels, ) @@ -545,10 +530,7 @@ def execute(self, context): bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, delegate_to=self.delegate_to) - conn = bq_hook.get_conn() - cursor = conn.cursor() - - cursor.delete_dataset( + bq_hook.delete_dataset( project_id=self.project_id, dataset_id=self.dataset_id ) @@ -608,10 +590,7 @@ def execute(self, context): bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, delegate_to=self.delegate_to) - conn = bq_hook.get_conn() - cursor = conn.cursor() - - cursor.create_empty_dataset( + bq_hook.create_empty_dataset( project_id=self.project_id, dataset_id=self.dataset_id, dataset_reference=self.dataset_reference) diff 
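With the cursor layer gone, run_query and the other hook methods the operators call above are expected to drive google-cloud-bigquery job configs directly. A sketch of the underlying client call for a query job, assuming default credentials; the dataset, table and label values are placeholders, and which options run_query actually forwards is an inference from the operator arguments above:

    from google.cloud import bigquery

    client = bigquery.Client()

    job_config = bigquery.QueryJobConfig()
    job_config.use_legacy_sql = False       # standard SQL is the new default
    job_config.labels = {"dag": "example"}  # hypothetical label
    job_config.destination = bigquery.DatasetReference(
        client.project, "my_dataset").table("my_table")
    job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE

    query_job = client.query("SELECT 1 AS x", job_config=job_config)
    query_job.result()                      # block until the job completes
    print(query_job.job_id, query_job.state)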
--git a/airflow/contrib/operators/bigquery_table_delete_operator.py b/airflow/contrib/operators/bigquery_table_delete_operator.py index 45c481454e475..9fb057d441f2d 100644 --- a/airflow/contrib/operators/bigquery_table_delete_operator.py +++ b/airflow/contrib/operators/bigquery_table_delete_operator.py @@ -61,6 +61,4 @@ def execute(self, context): self.log.info('Deleting: %s', self.deletion_dataset_table) hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, delegate_to=self.delegate_to) - conn = hook.get_conn() - cursor = conn.cursor() - cursor.run_table_delete(self.deletion_dataset_table, self.ignore_if_missing) + hook.run_table_delete(self.deletion_dataset_table, self.ignore_if_missing) diff --git a/airflow/contrib/operators/bigquery_to_bigquery.py b/airflow/contrib/operators/bigquery_to_bigquery.py index 288731e157de7..d2b77640b6c4d 100644 --- a/airflow/contrib/operators/bigquery_to_bigquery.py +++ b/airflow/contrib/operators/bigquery_to_bigquery.py @@ -85,9 +85,7 @@ def execute(self, context): ) hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, delegate_to=self.delegate_to) - conn = hook.get_conn() - cursor = conn.cursor() - cursor.run_copy( + hook.run_copy( self.source_project_dataset_tables, self.destination_project_dataset_table, self.write_disposition, diff --git a/airflow/contrib/operators/bigquery_to_gcs.py b/airflow/contrib/operators/bigquery_to_gcs.py index ec6b937d218b2..5fc67e3406f75 100644 --- a/airflow/contrib/operators/bigquery_to_gcs.py +++ b/airflow/contrib/operators/bigquery_to_gcs.py @@ -93,9 +93,7 @@ def execute(self, context): self.destination_cloud_storage_uris) hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, delegate_to=self.delegate_to) - conn = hook.get_conn() - cursor = conn.cursor() - cursor.run_extract( + hook.run_extract( self.source_project_dataset_table, self.destination_cloud_storage_uris, self.compression, diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py index abbf380b2d0fb..a7d2b7830404d 100644 --- a/airflow/contrib/operators/gcs_to_bq.py +++ b/airflow/contrib/operators/gcs_to_bq.py @@ -148,7 +148,7 @@ def __init__(self, google_cloud_storage_conn_id='google_cloud_default', delegate_to=None, schema_update_options=(), - src_fmt_configs=None, + external_config_options=None, external_table=False, time_partitioning=None, cluster_fields=None, @@ -158,10 +158,6 @@ def __init__(self, super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs) # GCS config - if src_fmt_configs is None: - src_fmt_configs = {} - if time_partitioning is None: - time_partitioning = {} self.bucket = bucket self.source_objects = source_objects self.schema_object = schema_object @@ -188,7 +184,7 @@ def __init__(self, self.delegate_to = delegate_to self.schema_update_options = schema_update_options - self.src_fmt_configs = src_fmt_configs + self.external_config_options, = external_config_options, self.time_partitioning = time_partitioning self.cluster_fields = cluster_fields self.autodetect = autodetect @@ -220,7 +216,7 @@ def execute(self, context): cursor = conn.cursor() if self.external_table: - cursor.create_external_table( + bq_hook.create_external_table( external_project_dataset_table=self.destination_project_dataset_table, schema_fields=schema_fields, source_uris=source_uris, @@ -233,10 +229,10 @@ def execute(self, context): ignore_unknown_values=self.ignore_unknown_values, allow_quoted_newlines=self.allow_quoted_newlines, allow_jagged_rows=self.allow_jagged_rows, - 
src_fmt_configs=self.src_fmt_configs + external_config_options=self.external_config_options, ) else: - cursor.run_load( + bq_hook.run_load( destination_project_dataset_table=self.destination_project_dataset_table, schema_fields=schema_fields, source_uris=source_uris, diff --git a/setup.py b/setup.py index afee79651f1f4..281df816ad07f 100644 --- a/setup.py +++ b/setup.py @@ -176,6 +176,7 @@ def write_version(filename=os.path.join(*['airflow', 'google-auth-httplib2>=0.0.1', 'google-cloud-container>=0.1.1', 'google-cloud-bigtable==0.31.0', + 'google-cloud-bigquery>=1.8.1', 'google-cloud-spanner>=1.7.1', 'grpcio-gcp>=0.2.2', 'PyOpenSSL', diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py index 9bae4abb2a642..f2215e261f855 100644 --- a/tests/contrib/hooks/test_bigquery_hook.py +++ b/tests/contrib/hooks/test_bigquery_hook.py @@ -18,24 +18,38 @@ # under the License. # +import os import unittest -from google.auth.exceptions import GoogleAuthError import mock -from googleapiclient.errors import HttpError +from google.cloud import bigquery +from airflow.utils import timezone +from airflow import AirflowException from airflow.contrib.hooks import bigquery_hook as hook -from airflow.contrib.hooks.bigquery_hook import _cleanse_time_partitioning, \ - _validate_value, _api_resource_configs_duplication_check +from airflow.contrib.hooks.bigquery_hook import _validate_value bq_available = True +if "TRAVIS" in os.environ and bool(os.environ["TRAVIS"]): + bq_available = False + try: - hook.BigQueryHook().get_service() -except GoogleAuthError: + bigquery.Client() +except Exception: bq_available = False +class BigQueryTestBase(unittest.TestCase): + def setUp(self): + super(BigQueryTestBase, self).setUp() + self.bq_client = mock.Mock() + self.bq_hook = hook.BigQueryHook() + patcher = mock.patch.object(hook.BigQueryHook, 'get_client') + self.addCleanup(patcher.stop) + patcher.start().return_value = self.bq_client + + class TestPandasGbqPrivateKey(unittest.TestCase): def setUp(self): self.instance = hook.BigQueryHook() @@ -109,41 +123,43 @@ def test_internal_need_default_project(self): str(context.exception), "") def test_split_dataset_table(self): - project, dataset, table = hook._split_tablename('dataset.table', - 'project') - self.assertEqual("project", project) - self.assertEqual("dataset", dataset) - self.assertEqual("table", table) + table_ref = hook._split_tablename('dataset.table', + 'project') + self.assertEqual('project', table_ref.project) + self.assertEqual("dataset", table_ref.dataset_id) + self.assertEqual("table", table_ref.table_id) def test_split_project_dataset_table(self): - project, dataset, table = hook._split_tablename('alternative:dataset.table', - 'project') - self.assertEqual("alternative", project) - self.assertEqual("dataset", dataset) - self.assertEqual("table", table) + table_ref = hook._split_tablename('alternative:dataset.table', + 'project') + + self.assertEqual('alternative', table_ref.project) + self.assertEqual("dataset", table_ref.dataset_id) + self.assertEqual("table", table_ref.table_id) def test_sql_split_project_dataset_table(self): - project, dataset, table = hook._split_tablename('alternative.dataset.table', - 'project') - self.assertEqual("alternative", project) - self.assertEqual("dataset", dataset) - self.assertEqual("table", table) + table_ref = hook._split_tablename('alternative.dataset.table', + 'project') + + self.assertEqual('alternative', table_ref.project) + self.assertEqual("dataset", table_ref.dataset_id) + 
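run_table_delete, run_copy and run_extract, which the operators above now call on the hook itself, presumably map onto the corresponding primitives of google-cloud-bigquery. A rough sketch of those client calls, with a placeholder project, dataset and bucket, and default credentials assumed:

    from google.cloud import bigquery

    client = bigquery.Client()
    dataset = bigquery.DatasetReference(client.project, "my_dataset")  # placeholder

    # Copy between tables; dataset.table() builds TableReference objects.
    copy_config = bigquery.CopyJobConfig()
    copy_config.write_disposition = "WRITE_EMPTY"
    client.copy_table(dataset.table("src"), dataset.table("dst"),
                      job_config=copy_config).result()

    # Extract to GCS.
    extract_config = bigquery.ExtractJobConfig()
    extract_config.compression = "GZIP"
    extract_config.destination_format = "NEWLINE_DELIMITED_JSON"
    client.extract_table(dataset.table("dst"),
                         ["gs://my-bucket/export/part-*.json.gz"],
                         job_config=extract_config).result()

    # Delete a table; raises NotFound if it is missing.
    client.delete_table(dataset.table("obsolete"))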
self.assertEqual("table", table_ref.table_id) def test_colon_in_project(self): - project, dataset, table = hook._split_tablename('alt1:alt.dataset.table', - 'project') + table_ref = hook._split_tablename('alt1:alt.dataset.table', + 'project') - self.assertEqual('alt1:alt', project) - self.assertEqual("dataset", dataset) - self.assertEqual("table", table) + self.assertEqual('alt1:alt', table_ref.project) + self.assertEqual("dataset", table_ref.dataset_id) + self.assertEqual("table", table_ref.table_id) def test_valid_double_column(self): - project, dataset, table = hook._split_tablename('alt1:alt:dataset.table', - 'project') + table_ref = hook._split_tablename('alt1:alt:dataset.table', + 'project') - self.assertEqual('alt1:alt', project) - self.assertEqual("dataset", dataset) - self.assertEqual("table", table) + self.assertEqual('alt1:alt', table_ref.project) + self.assertEqual("dataset", table_ref.dataset_id) + self.assertEqual("table", table_ref.table_id) def test_invalid_syntax_triple_colon(self): with self.assertRaises(Exception) as context: @@ -194,23 +210,11 @@ def test_invalid_syntax_triple_dot_var(self): str(context.exception), "") -class TestBigQueryHookSourceFormat(unittest.TestCase): - def test_invalid_source_format(self): - with self.assertRaises(Exception) as context: - hook.BigQueryBaseCursor("test", "test").run_load( - "test.test", "test_schema.json", ["test_data.json"], source_format="json" - ) - - # since we passed 'json' in, and it's not valid, make sure it's present in the - # error string. - self.assertIn("JSON", str(context.exception)) - - class TestBigQueryExternalTableSourceFormat(unittest.TestCase): def test_invalid_source_format(self): with self.assertRaises(Exception) as context: - hook.BigQueryBaseCursor("test", "test").create_external_table( - external_project_dataset_table='test.test', + hook.BigQueryHook().create_external_table( + external_project_dataset_table='test:test.test', schema_fields='test_schema.json', source_uris=['test_data.json'], source_format='json' @@ -221,106 +225,60 @@ def test_invalid_source_format(self): self.assertIn("JSON", str(context.exception)) -# Helpers to test_cancel_queries that have mock_poll_job_complete returning false, -# unless mock_job_cancel was called with the same job_id -mock_canceled_jobs = [] - - -def mock_poll_job_complete(job_id): - return job_id in mock_canceled_jobs - +class TestQueries(BigQueryTestBase): -def mock_job_cancel(projectId, jobId): - mock_canceled_jobs.append(jobId) - return mock.Mock() - - -class TestBigQueryBaseCursor(unittest.TestCase): def test_invalid_schema_update_options(self): with self.assertRaises(Exception) as context: - hook.BigQueryBaseCursor("test", "test").run_load( + self.bq_hook.run_load( "test.test", - "test_schema.json", ["test_data.json"], - schema_update_options=["THIS IS NOT VALID"] + project_id='project_id', + schema_update_options=["THIS IS NOT VALID"], + schema_fields=[], ) self.assertIn("THIS IS NOT VALID", str(context.exception)) def test_invalid_schema_update_and_write_disposition(self): with self.assertRaises(Exception) as context: - hook.BigQueryBaseCursor("test", "test").run_load( + self.bq_hook.run_load( "test.test", - "test_schema.json", ["test_data.json"], + project_id='project_id', schema_update_options=['ALLOW_FIELD_ADDITION'], - write_disposition='WRITE_EMPTY' + schema_fields=[], + write_disposition='WRITE_EMPTY', ) self.assertIn("schema_update_options is only", str(context.exception)) - def test_cancel_queries(self): - project_id = 12345 - running_job_id = 3 - - 
mock_jobs = mock.Mock() - mock_jobs.cancel = mock.Mock(side_effect=mock_job_cancel) - mock_service = mock.Mock() - mock_service.jobs = mock.Mock(return_value=mock_jobs) - - bq_hook = hook.BigQueryBaseCursor(mock_service, project_id) - bq_hook.running_job_id = running_job_id - bq_hook.poll_job_complete = mock.Mock(side_effect=mock_poll_job_complete) - - bq_hook.cancel_query() + def test_run_query_sql_dialect_default(self): + self.bq_hook.run_query('query') + self.bq_client.query.assert_called + args, kwargs = self.bq_client.query.call_args + self.assertFalse(kwargs['job_config'].use_legacy_sql) - mock_jobs.cancel.assert_called_with(projectId=project_id, jobId=running_job_id) - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_query_sql_dialect_default(self, run_with_config): - cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id") - cursor.run_query('query') - args, kwargs = run_with_config.call_args - self.assertIs(args[0]['query']['useLegacySql'], True) - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_query_sql_dialect_override(self, run_with_config): - for bool_val in [True, False]: - cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id") - cursor.run_query('query', use_legacy_sql=bool_val) - args, kwargs = run_with_config.call_args - self.assertIs(args[0]['query']['useLegacySql'], bool_val) - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_api_resource_configs(self, run_with_config): + def test_run_query_sql_dialect_override(self): for bool_val in [True, False]: - cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id") - cursor.run_query('query', - api_resource_configs={ - 'query': {'useQueryCache': bool_val}}) - args, kwargs = run_with_config.call_args - self.assertIs(args[0]['query']['useQueryCache'], bool_val) - self.assertIs(args[0]['query']['useLegacySql'], True) - - def test_api_resource_configs_duplication_warning(self): - with self.assertRaises(ValueError): - cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id") - cursor.run_query('query', - use_legacy_sql=True, - api_resource_configs={ - 'query': {'useLegacySql': False}}) + self.bq_hook.run_query('query', use_legacy_sql=bool_val) + args, kwargs = self.bq_client.query.call_args + self.assertIs(kwargs['job_config'].use_legacy_sql, bool_val) + + def test_run_query_labels(self): + self.bq_hook.run_query('query', labels={'foo': 'bar'}) + self.bq_client.query.assert_called + args, kwargs = self.bq_client.query.call_args + self.assertEqual(kwargs['job_config'].labels, {'foo': 'bar'}) + + def test_run_query_autodetect(self): + self.bq_hook.run_query('query', labels={'foo': 'bar'}) + self.bq_client.query.assert_called + args, kwargs = self.bq_client.query.call_args + self.assertEqual(kwargs['job_config'].labels, {'foo': 'bar'}) def test_validate_value(self): with self.assertRaises(TypeError): _validate_value("case_1", "a", dict) self.assertIsNone(_validate_value("case_2", 0, int)) - def test_duplication_check(self): - with self.assertRaises(ValueError): - key_one = True - _api_resource_configs_duplication_check( - "key_one", key_one, {"key_one": False}) - self.assertIsNone(_api_resource_configs_duplication_check( - "key_one", key_one, {"key_one": True})) - def test_insert_all_succeed(self): project_id = 'bq-project' dataset_id = 'bq_dataset' @@ -328,22 +286,18 @@ def test_insert_all_succeed(self): rows = [ {"json": {"a_key": "a_value_0"}} ] - body = { - "rows": rows, - "ignoreUnknownValues": False, - 
"kind": "bigquery#tableDataInsertAllRequest", - "skipInvalidRows": False, - } - - mock_service = mock.Mock() - method = mock_service.tabledata.return_value.insertAll - method.return_value.execute.return_value = { - "kind": "bigquery#tableDataInsertAllResponse" - } - cursor = hook.BigQueryBaseCursor(mock_service, 'project_id') - cursor.insert_all(project_id, dataset_id, table_id, rows) - method.assert_called_with(projectId=project_id, datasetId=dataset_id, - tableId=table_id, body=body) + + self.bq_client.insert_rows.return_value = [] + + self.bq_hook.insert_all(dataset_id, table_id, rows, project_id=project_id) + + self.bq_client.insert_rows.assert_called + args, kwargs = self.bq_client.insert_rows.call_args + self.assertEqual(args[1:], (rows, )) + self.assertEqual(kwargs, { + 'skip_invalid_rows': False, + 'ignore_unknown_values': False, + }) def test_insert_all_fail(self): project_id = 'bq-project' @@ -353,475 +307,120 @@ def test_insert_all_fail(self): {"json": {"a_key": "a_value_0"}} ] - mock_service = mock.Mock() - method = mock_service.tabledata.return_value.insertAll - method.return_value.execute.return_value = { - "kind": "bigquery#tableDataInsertAllResponse", - "insertErrors": [ - { - "index": 1, - "errors": [] - } - ] - } - cursor = hook.BigQueryBaseCursor(mock_service, 'project_id') - with self.assertRaises(Exception): - cursor.insert_all(project_id, dataset_id, table_id, - rows, fail_on_error=True) - - def test_create_view_fails_on_exception(self): - project_id = 'bq-project' - dataset_id = 'bq_dataset' - table_id = 'bq_table_view' - view = { - 'incorrect_key': 'SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*`', - "useLegacySql": False - } - - mock_service = mock.Mock() - method = mock_service.tables.return_value.insert - method.return_value.execute.side_effect = HttpError( - resp={'status': '400'}, content=b'Query is required for views') - cursor = hook.BigQueryBaseCursor(mock_service, project_id) - with self.assertRaises(Exception): - cursor.create_empty_table(project_id, dataset_id, table_id, - view=view) + self.bq_client.insert_rows.return_value = [{'errors': [{'reason': 'invalid'}]}] + + with self.assertRaises(AirflowException): + self.bq_hook.insert_all(dataset_id, table_id, rows, + project_id=project_id, fail_on_error=True) def test_create_view(self): project_id = 'bq-project' dataset_id = 'bq_dataset' table_id = 'bq_table_view' - view = { - 'query': 'SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*`', - "useLegacySql": False - } - - mock_service = mock.Mock() - method = mock_service.tables.return_value.insert - cursor = hook.BigQueryBaseCursor(mock_service, project_id) - cursor.create_empty_table(project_id, dataset_id, table_id, - view=view) - body = { - 'tableReference': { - 'tableId': table_id - }, - 'view': view - } - method.assert_called_once_with(projectId=project_id, datasetId=dataset_id, body=body) - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_patch_table(self, run_with_config): + view_query = 'SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*`' + + self.bq_client.insert_rows.return_value = [{'errors': [{'reason': 'invalid'}]}] + + self.bq_hook.create_empty_table(dataset_id, table_id, project_id=project_id, view=view_query) + + self.bq_client.create_table.assert_called + args, kwargs = self.bq_client.create_table.call_args + table = args[0] + self.assertEqual(table.view_query, view_query) + + def test_patch_table(self): project_id = 'bq-project' dataset_id = 'bq_dataset' 
table_id = 'bq_table' description_patched = 'Test description.' - expiration_time_patched = 2524608000000 + expiration_time_patched = timezone.datetime(2019, 4, 20) friendly_name_patched = 'Test friendly name.' labels_patched = {'label1': 'test1', 'label2': 'test2'} schema_patched = [ - {'name': 'id', 'type': 'STRING', 'mode': 'REQUIRED'}, - {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'balance', 'type': 'FLOAT', 'mode': 'NULLABLE'}, - {'name': 'new_field', 'type': 'STRING', 'mode': 'NULLABLE'} + bigquery.SchemaField('id', 'STRING', mode='REQUIRED'), + bigquery.SchemaField('name', 'STRING', mode='NULLABLE'), + bigquery.SchemaField('balance', 'FLOAT', mode='NULLABLE'), + bigquery.SchemaField('new_field', 'STRING', mode='NULLABLE'), ] - time_partitioning_patched = { - 'expirationMs': 10000000 - } - require_partition_filter_patched = True - - mock_service = mock.Mock() - method = (mock_service.tables.return_value.patch) - cursor = hook.BigQueryBaseCursor(mock_service, project_id) - cursor.patch_table( + time_partitioning_patched = bigquery.TimePartitioning() + + self.bq_hook.patch_table( dataset_id, table_id, project_id, description=description_patched, expiration_time=expiration_time_patched, friendly_name=friendly_name_patched, labels=labels_patched, schema=schema_patched, time_partitioning=time_partitioning_patched, - require_partition_filter=require_partition_filter_patched ) - body = { - "description": description_patched, - "expirationTime": expiration_time_patched, - "friendlyName": friendly_name_patched, - "labels": labels_patched, - "schema": { - "fields": schema_patched - }, - "timePartitioning": time_partitioning_patched, - "requirePartitionFilter": require_partition_filter_patched - } - method.assert_called_once_with( - projectId=project_id, - datasetId=dataset_id, - tableId=table_id, - body=body - ) + self.bq_client.update_table.assert_called + args, kwargs = self.bq_client.update_table.call_args + table = args[0] + self.assertEqual(table.description, description_patched) + self.assertEqual(table.expires, expiration_time_patched) - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_patch_view(self, run_with_config): + def test_patch_view(self): project_id = 'bq-project' dataset_id = 'bq_dataset' view_id = 'bq_view' - view_patched = { - 'query': "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 500", - 'useLegacySql': False - } - - mock_service = mock.Mock() - method = (mock_service.tables.return_value.patch) - cursor = hook.BigQueryBaseCursor(mock_service, project_id) - cursor.patch_table(dataset_id, view_id, project_id, view=view_patched) - body = { - 'view': view_patched - } - method.assert_called_once_with( - projectId=project_id, - datasetId=dataset_id, - tableId=view_id, - body=body - ) - - -class TestBigQueryCursor(unittest.TestCase): - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_execute_with_parameters(self, mocked_rwc): - hook.BigQueryCursor("test", "test").execute( - "SELECT %(foo)s", {"foo": "bar"}) - mocked_rwc.assert_called_once() - - -class TestLabelsInRunJob(unittest.TestCase): - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_query_with_arg(self, mocked_rwc): - project_id = 12345 + view_patched = "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 500" - def run_with_config(config): - self.assertEqual( - config['labels'], {'label1': 'test1', 'label2': 'test2'} - ) - mocked_rwc.side_effect = 
run_with_config - - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - bq_hook.run_query( - sql='select 1', - destination_dataset_table='my_dataset.my_table', - labels={'label1': 'test1', 'label2': 'test2'} + self.bq_hook.patch_table( + dataset_id, view_id, project_id, + view=view_patched, ) - mocked_rwc.assert_called_once() - - -class TestDatasetsOperations(unittest.TestCase): - - def test_create_empty_dataset_no_dataset_id_err(self): - - with self.assertRaises(ValueError): - hook.BigQueryBaseCursor( - mock.Mock(), "test_create_empty_dataset").create_empty_dataset( - dataset_id="", project_id="") + self.bq_client.update_table.assert_called + args, kwargs = self.bq_client.update_table.call_args + table = args[0] + self.assertEqual(table.view_query, view_patched) + self.assertEqual(kwargs['fields'], ['view_query']) - def test_create_empty_dataset_duplicates_call_err(self): - with self.assertRaises(ValueError): - hook.BigQueryBaseCursor( - mock.Mock(), "test_create_empty_dataset").create_empty_dataset( - dataset_id="", project_id="project_test", - dataset_reference={ - "datasetReference": - {"datasetId": "test_dataset", - "projectId": "project_test2"}}) - def test_get_dataset_without_dataset_id(self): - with mock.patch.object(hook.BigQueryHook, 'get_service'): - with self.assertRaises(ValueError): - hook.BigQueryBaseCursor( - mock.Mock(), "test_create_empty_dataset").get_dataset( - dataset_id="", project_id="project_test") +class TestDatasetsOperations(BigQueryTestBase): def test_get_dataset(self): - expected_result = { - "kind": "bigquery#dataset", - "location": "US", - "id": "your-project:dataset_2_test", - "datasetReference": { - "projectId": "your-project", - "datasetId": "dataset_2_test" - } - } dataset_id = "test_dataset" project_id = "project_test" - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - with mock.patch.object(bq_hook.service, 'datasets') as MockService: - MockService.return_value.get(datasetId=dataset_id, - projectId=project_id).execute.\ - return_value = expected_result - result = bq_hook.get_dataset(dataset_id=dataset_id, - project_id=project_id) - self.assertEqual(result, expected_result) - - def test_get_datasets_list(self): - expected_result = {'datasets': [ - { - "kind": "bigquery#dataset", - "location": "US", - "id": "your-project:dataset_2_test", - "datasetReference": { - "projectId": "your-project", - "datasetId": "dataset_2_test" - } - }, - { - "kind": "bigquery#dataset", - "location": "US", - "id": "your-project:dataset_1_test", - "datasetReference": { - "projectId": "your-project", - "datasetId": "dataset_1_test" - } - } - ]} - project_id = "project_test"'' - - mocked = mock.Mock() - with mock.patch.object(hook.BigQueryBaseCursor(mocked, project_id).service, - 'datasets') as MockService: - MockService.return_value.list( - projectId=project_id).execute.return_value = expected_result - result = hook.BigQueryBaseCursor( - mocked, "test_create_empty_dataset").get_datasets_list( - project_id=project_id) - self.assertEqual(result, expected_result['datasets']) - - -class TestTimePartitioningInRunJob(unittest.TestCase): - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_load_default(self, mocked_rwc): - project_id = 12345 - - def run_with_config(config): - self.assertIsNone(config['load'].get('timePartitioning')) - mocked_rwc.side_effect = run_with_config - - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - bq_hook.run_load( - destination_project_dataset_table='my_dataset.my_table', - 
schema_fields=[], - source_uris=[], - ) - - mocked_rwc.assert_called_once() - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_with_auto_detect(self, run_with_config): - destination_project_dataset_table = "autodetect.table" - cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id") - cursor.run_load(destination_project_dataset_table, [], [], autodetect=True) - args, kwargs = run_with_config.call_args - self.assertIs(args[0]['load']['autodetect'], True) - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_load_with_arg(self, mocked_rwc): - project_id = 12345 - - def run_with_config(config): - self.assertEqual( - config['load']['timePartitioning'], - { - 'field': 'test_field', - 'type': 'DAY', - 'expirationMs': 1000 - } - ) - mocked_rwc.side_effect = run_with_config - - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - bq_hook.run_load( - destination_project_dataset_table='my_dataset.my_table', - schema_fields=[], - source_uris=[], - time_partitioning={'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000} - ) - - mocked_rwc.assert_called_once() - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_query_default(self, mocked_rwc): - project_id = 12345 - - def run_with_config(config): - self.assertIsNone(config['query'].get('timePartitioning')) - mocked_rwc.side_effect = run_with_config - - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - bq_hook.run_query(sql='select 1') - - mocked_rwc.assert_called_once() - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_query_with_arg(self, mocked_rwc): - project_id = 12345 - - def run_with_config(config): - self.assertEqual( - config['query']['timePartitioning'], - { - 'field': 'test_field', - 'type': 'DAY', - 'expirationMs': 1000 - } - ) - mocked_rwc.side_effect = run_with_config - - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - bq_hook.run_query( - sql='select 1', - destination_dataset_table='my_dataset.my_table', - time_partitioning={'type': 'DAY', - 'field': 'test_field', 'expirationMs': 1000} - ) - - mocked_rwc.assert_called_once() - - def test_dollar_makes_partition(self): - tp_out = _cleanse_time_partitioning('test.teast$20170101', {}) - expect = { - 'type': 'DAY' - } - self.assertEqual(tp_out, expect) + dataset = self.bq_hook.get_dataset(dataset_id=dataset_id, project_id=project_id) - def test_extra_time_partitioning_options(self): - tp_out = _cleanse_time_partitioning( - 'test.teast', - {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000} - ) - - expect = { - 'type': 'DAY', - 'field': 'test_field', - 'expirationMs': 1000 - } - self.assertEqual(tp_out, expect) - - -class TestClusteringInRunJob(unittest.TestCase): - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_load_default(self, mocked_rwc): - project_id = 12345 - - def run_with_config(config): - self.assertIsNone(config['load'].get('clustering')) - mocked_rwc.side_effect = run_with_config + self.assertEqual(dataset, self.bq_client.get_dataset.return_value) - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - bq_hook.run_load( - destination_project_dataset_table='my_dataset.my_table', - schema_fields=[], - source_uris=[], - ) - - mocked_rwc.assert_called_once() - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_load_with_arg(self, mocked_rwc): - project_id = 12345 - - def run_with_config(config): - 
self.assertEqual( - config['load']['clustering'], - { - 'fields': ['field1', 'field2'] - } - ) - mocked_rwc.side_effect = run_with_config - - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - bq_hook.run_load( - destination_project_dataset_table='my_dataset.my_table', - schema_fields=[], - source_uris=[], - cluster_fields=['field1', 'field2'], - time_partitioning={'type': 'DAY'} - ) - - mocked_rwc.assert_called_once() - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_query_default(self, mocked_rwc): - project_id = 12345 - - def run_with_config(config): - self.assertIsNone(config['query'].get('clustering')) - mocked_rwc.side_effect = run_with_config - - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - bq_hook.run_query(sql='select 1') - - mocked_rwc.assert_called_once() - - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_run_query_with_arg(self, mocked_rwc): - project_id = 12345 - - def run_with_config(config): - self.assertEqual( - config['query']['clustering'], - { - 'fields': ['field1', 'field2'] - } - ) - mocked_rwc.side_effect = run_with_config - - bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id) - bq_hook.run_query( - sql='select 1', - destination_dataset_table='my_dataset.my_table', - cluster_fields=['field1', 'field2'], - time_partitioning={'type': 'DAY'} - ) + def test_get_datasets_list(self): + project_id = "project_test" - mocked_rwc.assert_called_once() + self.bq_client.list_datasets.return_value = [] + datasets = self.bq_hook.list_datasets(project_id=project_id) + self.assertEqual(datasets, self.bq_client.list_datasets.return_value) -class TestBigQueryHookLegacySql(unittest.TestCase): - """Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly.""" - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_hook_uses_legacy_sql_by_default(self, run_with_config): - with mock.patch.object(hook.BigQueryHook, 'get_service'): - bq_hook = hook.BigQueryHook() - bq_hook.get_first('query') - args, kwargs = run_with_config.call_args - self.assertIs(args[0]['query']['useLegacySql'], True) +class TestLoad(BigQueryTestBase): - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_legacy_sql_override_propagates_properly(self, run_with_config): - with mock.patch.object(hook.BigQueryHook, 'get_service'): - bq_hook = hook.BigQueryHook(use_legacy_sql=False) - bq_hook.get_first('query') - args, kwargs = run_with_config.call_args - self.assertIs(args[0]['query']['useLegacySql'], False) + def test_load(self): + project_id = 'project' + destination = 'dataset.table' + source_uris = ['gs://bucket/path'] + schema = [] + self.bq_hook.run_load(destination, source_uris, project_id=project_id, schema_fields=schema) + self.bq_client.load_table_from_uri.assert_called() + self.bq_client.load_table_from_uri.return_value.result.assert_called() -class TestBigQueryHookLocation(unittest.TestCase): - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_location_propagates_properly(self, run_with_config): - with mock.patch.object(hook.BigQueryHook, 'get_service'): - bq_hook = hook.BigQueryHook(location=None) - self.assertIsNone(bq_hook.location) + def test_load_time_partitioning(self): + project_id = 'project' + destination = 'dataset.table' + source_uris = ['gs://bucket/path'] + schema = [] + partitioning = bigquery.TimePartitioning() + self.bq_hook.run_load(destination, source_uris, project_id=project_id, 
schema_fields=schema, + time_partitioning=partitioning) - bq_cursor = hook.BigQueryBaseCursor(mock.Mock(), - 'test-project', - location=None) - self.assertIsNone(bq_cursor.location) - bq_cursor.run_query(sql='select 1', location='US') - run_with_config.assert_called_once() - self.assertEqual(bq_cursor.location, 'US') + self.bq_client.load_table_from_uri.assert_called() + self.bq_client.load_table_from_uri.return_value.result.assert_called() + args, kwargs = self.bq_client.load_table_from_uri.call_args -if __name__ == '__main__': - unittest.main() + self.assertEqual(kwargs['job_config'].time_partitioning, partitioning) diff --git a/tests/contrib/operators/test_bigquery_operator.py b/tests/contrib/operators/test_bigquery_operator.py index 304699410992a..af046125a82cb 100644 --- a/tests/contrib/operators/test_bigquery_operator.py +++ b/tests/contrib/operators/test_bigquery_operator.py @@ -61,8 +61,6 @@ def test_execute(self, mock_hook): operator.execute(None) mock_hook.return_value \ - .get_conn() \ - .cursor() \ .create_empty_table \ .assert_called_once_with( dataset_id=TEST_DATASET, @@ -91,8 +89,6 @@ def test_execute(self, mock_hook): operator.execute(None) mock_hook.return_value \ - .get_conn() \ - .cursor() \ .create_external_table \ .assert_called_once_with( external_project_dataset_table='{}.{}'.format( @@ -109,7 +105,7 @@ def test_execute(self, mock_hook): quote_character=None, allow_quoted_newlines=False, allow_jagged_rows=False, - src_fmt_configs={}, + external_config_options=None, labels=None ) @@ -125,8 +121,6 @@ def test_execute(self, mock_hook): operator.execute(None) mock_hook.return_value \ - .get_conn() \ - .cursor() \ .delete_dataset \ .assert_called_once_with( dataset_id=TEST_DATASET, @@ -145,8 +139,6 @@ def test_execute(self, mock_hook): operator.execute(None) mock_hook.return_value \ - .get_conn() \ - .cursor() \ .create_empty_dataset \ .assert_called_once_with( dataset_id=TEST_DATASET, @@ -183,7 +175,7 @@ def test_execute(self, mock_hook): flatten_results=None, bigquery_conn_id='bigquery_default', udf_config=None, - use_legacy_sql=True, + use_legacy_sql=False, maximum_billing_tier=None, maximum_bytes_billed=None, create_disposition='CREATE_IF_NEEDED', @@ -198,8 +190,6 @@ def test_execute(self, mock_hook): operator.execute(None) mock_hook.return_value \ - .get_conn() \ - .cursor() \ .run_query \ .assert_called_once_with( sql='Select * from test_table', @@ -231,8 +221,6 @@ def test_bigquery_operator_defaults(self, mock_hook): operator.execute(None) mock_hook.return_value \ - .get_conn() \ - .cursor() \ .run_query \ .assert_called_once_with( sql='Select * from test_table',
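For reference, the load path exercised by TestLoad above comes down to Client.load_table_from_uri() with a LoadJobConfig. A minimal sketch, assuming default credentials and placeholder dataset and bucket names; the exact options run_load forwards are inferred from those tests:

    from google.cloud import bigquery

    client = bigquery.Client()
    destination = bigquery.DatasetReference(client.project, "my_dataset").table("events")

    job_config = bigquery.LoadJobConfig()
    job_config.schema = [
        bigquery.SchemaField("id", "STRING", mode="REQUIRED"),
        bigquery.SchemaField("ts", "TIMESTAMP"),
    ]
    job_config.source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON
    job_config.time_partitioning = bigquery.TimePartitioning(
        type_=bigquery.TimePartitioningType.DAY, field="ts")

    load_job = client.load_table_from_uri(
        ["gs://my-bucket/events/*.json"], destination, job_config=job_config)
    load_job.result()   # wait for completion, as the tests assert via .result()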