diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index aee84e9eef741..44ecd49e9edcd 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -24,6 +24,7 @@
 import time
 
 from builtins import range
+from copy import deepcopy
 
 from past.builtins import basestring
 
@@ -195,10 +196,19 @@ class BigQueryBaseCursor(LoggingMixin):
     PEP 249 cursor isn't needed.
     """
 
-    def __init__(self, service, project_id, use_legacy_sql=True):
+    def __init__(self,
+                 service,
+                 project_id,
+                 use_legacy_sql=True,
+                 api_resource_configs=None):
+
         self.service = service
         self.project_id = project_id
         self.use_legacy_sql = use_legacy_sql
+        if api_resource_configs:
+            _validate_value("api_resource_configs", api_resource_configs, dict)
+        self.api_resource_configs = api_resource_configs \
+            if api_resource_configs else {}
         self.running_job_id = None
 
     def create_empty_table(self,
@@ -238,8 +248,7 @@ def create_empty_table(self,
 
         :return:
         """
-        if time_partitioning is None:
-            time_partitioning = dict()
+
         project_id = project_id if project_id is not None else self.project_id
 
         table_resource = {
@@ -473,11 +482,11 @@ def create_external_table(self,
     def run_query(self,
                   bql=None,
                   sql=None,
-                  destination_dataset_table=False,
+                  destination_dataset_table=None,
                   write_disposition='WRITE_EMPTY',
                   allow_large_results=False,
                   flatten_results=None,
-                  udf_config=False,
+                  udf_config=None,
                   use_legacy_sql=None,
                   maximum_billing_tier=None,
                   maximum_bytes_billed=None,
@@ -486,7 +495,8 @@ def run_query(self,
                   labels=None,
                   schema_update_options=(),
                   priority='INTERACTIVE',
-                  time_partitioning=None):
+                  time_partitioning=None,
+                  api_resource_configs=None):
         """
         Executes a BigQuery SQL query. Optionally persists results in a BigQuery
         table. See here:
@@ -518,6 +528,13 @@ def run_query(self,
         :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL
             (false). If `None`, defaults to `self.use_legacy_sql`.
         :type use_legacy_sql: boolean
+        :param api_resource_configs: a dictionary containing the 'configuration'
+            params to be applied to the Google BigQuery Jobs API:
+            https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
+            for example, {'query': {'useQueryCache': False}}. Use it if you
+            need to pass params that are not exposed as native arguments of
+            the BigQueryHook.
+        :type api_resource_configs: dict
         :type udf_config: list
         :param maximum_billing_tier: Positive integer that serves as a
             multiplier of the basic price.
@@ -550,12 +567,22 @@ def run_query(self,
         :type time_partitioning: dict
         """
+        if not api_resource_configs:
+            api_resource_configs = self.api_resource_configs
+        else:
+            _validate_value('api_resource_configs',
+                            api_resource_configs, dict)
+        configuration = deepcopy(api_resource_configs)
+        if 'query' not in configuration:
+            configuration['query'] = {}
+
+        else:
+            _validate_value("api_resource_configs['query']",
+                            configuration['query'], dict)
-        # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
-        if time_partitioning is None:
-            time_partitioning = {}
         sql = bql if sql is None else sql
+        # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
 
         if bql:
             import warnings
             warnings.warn('Deprecated parameter `bql` used in '
@@ -566,95 +593,109 @@ def run_query(self,
                           'Airflow.',
                           category=DeprecationWarning)
 
-        if sql is None:
-            raise TypeError('`BigQueryBaseCursor.run_query` missing 1 required '
-                            'positional argument: `sql`')
+        if sql is None and not configuration['query'].get('query', None):
+            raise TypeError('`BigQueryBaseCursor.run_query` '
+                            'missing 1 required positional argument: `sql`')
 
         # BigQuery also allows you to define how you want a table's schema to change
         # as a side effect of a query job
         # for more details:
         #   https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions
+
         allowed_schema_update_options = [
             'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"
         ]
-        if not set(allowed_schema_update_options).issuperset(
-                set(schema_update_options)):
-            raise ValueError(
-                "{0} contains invalid schema update options. "
-                "Please only use one or more of the following options: {1}"
-                .format(schema_update_options, allowed_schema_update_options))
 
-        if use_legacy_sql is None:
-            use_legacy_sql = self.use_legacy_sql
+        if not set(allowed_schema_update_options
+                   ).issuperset(set(schema_update_options)):
+            raise ValueError("{0} contains invalid schema update options. "
+                             "Please only use one or more of the following "
+                             "options: {1}"
+                             .format(schema_update_options,
+                                     allowed_schema_update_options))
 
-        configuration = {
-            'query': {
-                'query': sql,
-                'useLegacySql': use_legacy_sql,
-                'maximumBillingTier': maximum_billing_tier,
-                'maximumBytesBilled': maximum_bytes_billed,
-                'priority': priority
-            }
-        }
+        if schema_update_options:
+            if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
+                raise ValueError("schema_update_options is only "
+                                 "allowed if write_disposition is "
+                                 "'WRITE_APPEND' or 'WRITE_TRUNCATE'.")
 
         if destination_dataset_table:
-            if '.' not in destination_dataset_table:
-                raise ValueError(
-                    'Expected destination_dataset_table name in the format of '
-                    '<dataset>.<table>. Got: {}'.format(
-                        destination_dataset_table))
             destination_project, destination_dataset, destination_table = \
                 _split_tablename(table_input=destination_dataset_table,
                                  default_project_id=self.project_id)
-            configuration['query'].update({
-                'allowLargeResults': allow_large_results,
-                'flattenResults': flatten_results,
-                'writeDisposition': write_disposition,
-                'createDisposition': create_disposition,
-                'destinationTable': {
-                    'projectId': destination_project,
-                    'datasetId': destination_dataset,
-                    'tableId': destination_table,
-                }
-            })
-        if udf_config:
-            if not isinstance(udf_config, list):
-                raise TypeError("udf_config argument must have a type 'list'"
-                                " not {}".format(type(udf_config)))
-            configuration['query'].update({
-                'userDefinedFunctionResources': udf_config
-            })
 
-        if query_params:
-            if self.use_legacy_sql:
-                raise ValueError("Query parameters are not allowed when using "
-                                 "legacy SQL")
-            else:
-                configuration['query']['queryParameters'] = query_params
+            destination_dataset_table = {
+                'projectId': destination_project,
+                'datasetId': destination_dataset,
+                'tableId': destination_table,
+            }
 
-        if labels:
-            configuration['labels'] = labels
+        query_param_list = [
+            (sql, 'query', None, str),
+            (priority, 'priority', 'INTERACTIVE', str),
+            (use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool),
+            (query_params, 'queryParameters', None, dict),
+            (udf_config, 'userDefinedFunctionResources', None, list),
+            (maximum_billing_tier, 'maximumBillingTier', None, int),
+            (maximum_bytes_billed, 'maximumBytesBilled', None, float),
+            (time_partitioning, 'timePartitioning', {}, dict),
+            (schema_update_options, 'schemaUpdateOptions', None, tuple),
+            (destination_dataset_table, 'destinationTable', None, dict)
+        ]
 
-        time_partitioning = _cleanse_time_partitioning(
-            destination_dataset_table,
-            time_partitioning
-        )
-        if time_partitioning:
-            configuration['query'].update({
-                'timePartitioning': time_partitioning
-            })
+        for param_tuple in query_param_list:
 
-        if schema_update_options:
-            if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
-                raise ValueError("schema_update_options is only "
-                                 "allowed if write_disposition is "
-                                 "'WRITE_APPEND' or 'WRITE_TRUNCATE'.")
-            else:
-                self.log.info(
-                    "Adding experimental "
-                    "'schemaUpdateOptions': {0}".format(schema_update_options))
-                configuration['query'][
-                    'schemaUpdateOptions'] = schema_update_options
+            param, param_name, param_default, param_type = param_tuple
+
+            if param_name not in configuration['query'] and param in [None, {}, ()]:
+                if param_name == 'timePartitioning':
+                    param_default = _cleanse_time_partitioning(
+                        destination_dataset_table, time_partitioning)
+                param = param_default
+
+            if param not in [None, {}, ()]:
+                _api_resource_configs_duplication_check(
+                    param_name, param, configuration['query'])
+
+                configuration['query'][param_name] = param
+
+                # check the type of the provided param as the last step,
+                # because the param can come from two sources (the method
+                # arguments or api_resource_configs) and has to be resolved first
+
+                _validate_value(param_name, configuration['query'][param_name],
+                                param_type)
+
+                if param_name == 'schemaUpdateOptions' and param:
+                    self.log.info("Adding experimental 'schemaUpdateOptions': "
+                                  "{0}".format(schema_update_options))
+
+                if param_name == 'destinationTable':
+                    for key in ['projectId', 'datasetId', 'tableId']:
+                        if key not in configuration['query']['destinationTable']:
+                            raise ValueError(
+                                "Invalid 'destinationTable' in "
+                                "api_resource_configs. 'destinationTable' "
+                                "must be a dict with {'projectId': '', "
+                                "'datasetId': '', 'tableId': ''}")
+
+        configuration['query'].update({
+            'allowLargeResults': allow_large_results,
+            'flattenResults': flatten_results,
+            'writeDisposition': write_disposition,
+            'createDisposition': create_disposition,
+        })
+
+        if configuration['query']['useLegacySql'] and \
+                'queryParameters' in configuration['query']:
+            raise ValueError("Query parameters are not allowed "
+                             "when using legacy SQL")
+
+        if labels:
+            _api_resource_configs_duplication_check(
+                'labels', labels, configuration)
+            configuration['labels'] = labels
 
         return self.run_with_configuration(configuration)
 
@@ -888,8 +929,7 @@ def run_load(self,
         # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat
         if src_fmt_configs is None:
             src_fmt_configs = {}
-        if time_partitioning is None:
-            time_partitioning = {}
+
         source_format = source_format.upper()
         allowed_formats = [
             "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS",
@@ -1167,10 +1207,6 @@ def run_table_delete(self, deletion_dataset_table,
         :type ignore_if_missing: boolean
         :return:
         """
-        if '.' not in deletion_dataset_table:
-            raise ValueError(
-                'Expected deletion_dataset_table name in the format of '
-                '<dataset>.<table>. Got: {}'.format(deletion_dataset_table))
         deletion_project, deletion_dataset, deletion_table = \
             _split_tablename(table_input=deletion_dataset_table,
                              default_project_id=self.project_id)
@@ -1536,6 +1572,12 @@ def _bq_cast(string_field, bq_type):
 
 
 def _split_tablename(table_input, default_project_id, var_name=None):
+
+    if '.' not in table_input:
+        raise ValueError(
+            'Expected target table name in the format of '
+            '<dataset>.<table>. Got: {}'.format(table_input))
+
     if not default_project_id:
         raise ValueError("INTERNAL: No default project is specified")
 
@@ -1597,6 +1639,10 @@ def var_print(var_name):
 
 def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
     # if it is a partitioned table ($ is in the table name) add partition load option
+
+    if time_partitioning_in is None:
+        time_partitioning_in = {}
+
     time_partitioning_out = {}
     if destination_dataset_table and '$' in destination_dataset_table:
         if time_partitioning_in.get('field'):
@@ -1607,3 +1653,20 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
     time_partitioning_out.update(time_partitioning_in)
     return time_partitioning_out
+
+
+def _validate_value(key, value, expected_type):
+    """Check that the value has the expected type and raise a TypeError
+    if it does not."""
+    if not isinstance(value, expected_type):
+        raise TypeError("{} argument must have a type {} not {}".format(
+            key, expected_type, type(value)))
+
+
+def _api_resource_configs_duplication_check(key, value, config_dict):
+    if key in config_dict and value != config_dict[key]:
+        raise ValueError("Values of {param_name} param are duplicated. "
+                         "{param_name} was provided in `api_resource_configs` "
+                         "and also passed as an argument to the run_query() "
+                         "method. Please remove the duplicate."
+                         .format(param_name=key))
diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py
index a8fdc66ec9631..b0c0ce2d6e31b 100644
--- a/airflow/contrib/operators/bigquery_operator.py
+++ b/airflow/contrib/operators/bigquery_operator.py
@@ -75,6 +75,13 @@ class BigQueryOperator(BaseOperator):
         (without incurring a charge). If unspecified, this will be set to your
         project default.
     :type maximum_bytes_billed: float
+    :param api_resource_configs: a dictionary containing the 'configuration'
+        params to be applied to the Google BigQuery Jobs API:
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
+        for example, {'query': {'useQueryCache': False}}. Use it if you need
+        to pass params that are not exposed as native arguments of
+        BigQueryOperator.
+    :type api_resource_configs: dict
     :param schema_update_options: Allows the schema of the destination
         table to be updated as a side effect of the load job.
    :type schema_update_options: tuple
@@ -118,7 +125,8 @@ def __init__(self,
                  query_params=None,
                  labels=None,
                  priority='INTERACTIVE',
-                 time_partitioning={},
+                 time_partitioning=None,
+                 api_resource_configs=None,
                  *args,
                  **kwargs):
         super(BigQueryOperator, self).__init__(*args, **kwargs)
@@ -140,7 +148,10 @@ def __init__(self,
         self.labels = labels
         self.bq_cursor = None
         self.priority = priority
-        self.time_partitioning = time_partitioning
+        self.time_partitioning = {} if time_partitioning is None \
+            else time_partitioning
+        self.api_resource_configs = {} if api_resource_configs is None \
+            else api_resource_configs
 
         # TODO remove `bql` in Airflow 2.0
         if self.bql:
@@ -179,7 +190,8 @@ def execute(self, context):
                 labels=self.labels,
                 schema_update_options=self.schema_update_options,
                 priority=self.priority,
-                time_partitioning=self.time_partitioning
+                time_partitioning=self.time_partitioning,
+                api_resource_configs=self.api_resource_configs,
             )
 
     def on_kill(self):
diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
index d7e9491b8b8d4..69a103bf0689f 100644
--- a/tests/contrib/hooks/test_bigquery_hook.py
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -25,7 +25,8 @@
 import mock
 
 from airflow.contrib.hooks import bigquery_hook as hook
-from airflow.contrib.hooks.bigquery_hook import _cleanse_time_partitioning
+from airflow.contrib.hooks.bigquery_hook import _cleanse_time_partitioning, \
+    _validate_value, _api_resource_configs_duplication_check
 
 bq_available = True
 
@@ -206,6 +207,16 @@ def mock_job_cancel(projectId, jobId):
 
 
 class TestBigQueryBaseCursor(unittest.TestCase):
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_bql_deprecation_warning(self, mock_rwc):
+        with warnings.catch_warnings(record=True) as w:
+            hook.BigQueryBaseCursor("test", "test").run_query(
+                bql='select * from test_table'
+            )
+        self.assertIn(
+            'Deprecated parameter `bql`',
+            w[0].message.args[0])
+
     def test_invalid_schema_update_options(self):
         with self.assertRaises(Exception) as context:
             hook.BigQueryBaseCursor("test", "test").run_load(
@@ -216,16 +227,6 @@ def test_invalid_schema_update_options(self):
             )
         self.assertIn("THIS IS NOT VALID", str(context.exception))
 
-    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
-    def test_bql_deprecation_warning(self, mock_rwc):
-        with warnings.catch_warnings(record=True) as w:
-            hook.BigQueryBaseCursor("test", "test").run_query(
-                bql='select * from test_table'
-            )
-        self.assertIn(
-            'Deprecated parameter `bql`',
-            w[0].message.args[0])
-
     def test_nobql_nosql_param_error(self):
         with self.assertRaises(TypeError) as context:
             hook.BigQueryBaseCursor("test", "test").run_query(
@@ -281,6 +282,39 @@ def test_run_query_sql_dialect_override(self, run_with_config):
             args, kwargs = run_with_config.call_args
             self.assertIs(args[0]['query']['useLegacySql'], bool_val)
 
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_api_resource_configs(self, run_with_config):
+        for bool_val in [True, False]:
+            cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id")
+            cursor.run_query('query',
+                             api_resource_configs={
+                                 'query': {'useQueryCache': bool_val}})
+            args, kwargs = run_with_config.call_args
+            self.assertIs(args[0]['query']['useQueryCache'], bool_val)
+            self.assertIs(args[0]['query']['useLegacySql'], True)
+
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_api_resource_configs_duplication_warning(self, run_with_config):
+        with self.assertRaises(ValueError):
+            cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id")
+            cursor.run_query('query',
+                             use_legacy_sql=True,
+                             api_resource_configs={
+                                 'query': {'useLegacySql': False}})
+
+    def test_validate_value(self):
+        with self.assertRaises(TypeError):
+            _validate_value("case_1", "a", dict)
+        self.assertIsNone(_validate_value("case_2", 0, int))
+
+    def test_duplication_check(self):
+        key_one = True
+        with self.assertRaises(ValueError):
+            _api_resource_configs_duplication_check(
+                "key_one", key_one, {"key_one": False})
+        self.assertIsNone(_api_resource_configs_duplication_check(
+            "key_one", key_one, {"key_one": True}))
+
 
 class TestLabelsInRunJob(unittest.TestCase):
     @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
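
Reviewer note (not part of the patch): a minimal sketch of how the new `api_resource_configs` argument could be exercised from a DAG, assuming a configured BigQuery connection is available. The DAG id, task id, and SQL below are illustrative placeholders, not values taken from this change.

# Hypothetical usage sketch for the api_resource_configs argument added above.
from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.bigquery_operator import BigQueryOperator

with DAG(dag_id='example_bq_api_resource_configs',
         start_date=datetime(2018, 1, 1),
         schedule_interval=None) as dag:
    # Pass a raw Jobs API 'configuration' fragment that is not exposed as an
    # explicit operator argument (here: disabling the BigQuery query cache).
    query_without_cache = BigQueryOperator(
        task_id='query_without_cache',
        sql='SELECT 1',
        use_legacy_sql=False,
        api_resource_configs={'query': {'useQueryCache': False}},
    )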