diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index 47674d988d02..3097d0210c7c 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -362,8 +362,6 @@ def compute_table_name(row):
 NOTE: This job name template does not have backwards compatibility guarantees.
 """
 BQ_JOB_NAME_TEMPLATE = "beam_bq_job_{job_type}_{job_id}_{step_id}{random}"
-"""The number of shards per destination when writing via streaming inserts."""
-DEFAULT_SHARDS_PER_DESTINATION = 500

 @deprecated(since='2.11.0', current="bigquery_tools.parse_table_reference")
@@ -1081,7 +1079,8 @@ def __init__(
       max_buffered_rows=None,
       retry_strategy=None,
       additional_bq_parameters=None,
-      ignore_insert_ids=False):
+      ignore_insert_ids=False,
+      with_batched_input=False):
     """Initialize a WriteToBigQuery transform.

     Args:
@@ -1122,6 +1121,9 @@ def __init__(
         duplication of data inserted to BigQuery, set `ignore_insert_ids`
         to True to increase the throughput for BQ writing. See:
         https://cloud.google.com/bigquery/streaming-data-into-bigquery#disabling_best_effort_de-duplication
+      with_batched_input: Whether the input has already been batched per
+        destination. If not, perform best-effort batching per destination within
+        a bundle.
     """
     self.schema = schema
     self.test_client = test_client
@@ -1142,6 +1144,7 @@ def __init__(
         max_buffered_rows or BigQueryWriteFn.DEFAULT_MAX_BUFFERED_ROWS)
     self._retry_strategy = retry_strategy or RetryStrategy.RETRY_ALWAYS
     self.ignore_insert_ids = ignore_insert_ids
+    self.with_batched_input = with_batched_input

     self.additional_bq_parameters = additional_bq_parameters or {}

@@ -1254,13 +1257,19 @@ def process(self, element, *schema_side_inputs):

     destination = bigquery_tools.get_hashable_destination(destination)

-    row_and_insert_id = element[1]
-    self._rows_buffer[destination].append(row_and_insert_id)
-    self._total_buffered_rows += 1
-    if len(self._rows_buffer[destination]) >= self._max_batch_size:
+    if not self.with_batched_input:
+      row_and_insert_id = element[1]
+      self._rows_buffer[destination].append(row_and_insert_id)
+      self._total_buffered_rows += 1
+      if len(self._rows_buffer[destination]) >= self._max_batch_size:
+        return self._flush_batch(destination)
+      elif self._total_buffered_rows >= self._max_buffered_rows:
+        return self._flush_all_batches()
+    else:
+      # The input is already batched per destination; flush the rows now.
+      batched_rows = element[1]
+      self._rows_buffer[destination].extend(batched_rows)
       return self._flush_batch(destination)
-    elif self._total_buffered_rows >= self._max_buffered_rows:
-      return self._flush_all_batches()

   def finish_bundle(self):
     bigquery_tools.BigQueryWrapper.HISTOGRAM_METRIC_LOGGER.log_metrics(
@@ -1348,6 +1357,13 @@ def _flush_batch(self, destination):
   ]


+# The number of shards per destination when writing via streaming inserts.
+DEFAULT_SHARDS_PER_DESTINATION = 500
+# The max duration a batch of elements is allowed to be buffered before being
+# flushed to BigQuery.
+DEFAULT_BATCH_BUFFERING_DURATION_LIMIT_SEC = 0.2 + + class _StreamToBigQuery(PTransform): def __init__( self, @@ -1362,6 +1378,7 @@ def __init__( retry_strategy, additional_bq_parameters, ignore_insert_ids, + with_auto_sharding, test_client=None): self.table_reference = table_reference self.table_side_inputs = table_side_inputs @@ -1375,11 +1392,9 @@ def __init__( self.test_client = test_client self.additional_bq_parameters = additional_bq_parameters self.ignore_insert_ids = ignore_insert_ids + self.with_auto_sharding = with_auto_sharding class InsertIdPrefixFn(DoFn): - def __init__(self, shards=DEFAULT_SHARDS_PER_DESTINATION): - self.shards = shards - def start_bundle(self): self.prefix = str(uuid.uuid4()) self._row_count = 0 @@ -1387,8 +1402,6 @@ def start_bundle(self): def process(self, element): key = element[0] value = element[1] - key = (key, random.randint(0, self.shards)) - insert_id = '%s-%s' % (self.prefix, self._row_count) self._row_count += 1 yield (key, (value, insert_id)) @@ -1403,27 +1416,56 @@ def expand(self, input): retry_strategy=self.retry_strategy, test_client=self.test_client, additional_bq_parameters=self.additional_bq_parameters, - ignore_insert_ids=self.ignore_insert_ids) + ignore_insert_ids=self.ignore_insert_ids, + with_batched_input=self.with_auto_sharding) + + def _add_random_shard(element): + key = element[0] + value = element[1] + return ((key, random.randint(0, DEFAULT_SHARDS_PER_DESTINATION)), value) + + def _to_hashable_table_ref(table_ref_elem_kv): + table_ref = table_ref_elem_kv[0] + hashable_table_ref = bigquery_tools.get_hashable_destination(table_ref) + return (hashable_table_ref, table_ref_elem_kv[1]) - def drop_shard(elms): - key_and_shard = elms[0] - key = key_and_shard[0] - value = elms[1] - return (key, value) + def _restore_table_ref(sharded_table_ref_elems_kv): + sharded_table_ref = sharded_table_ref_elems_kv[0] + table_ref = bigquery_tools.parse_table_reference(sharded_table_ref.key) + return (table_ref, sharded_table_ref_elems_kv[1]) - sharded_data = ( + tagged_data = ( input | 'AppendDestination' >> beam.ParDo( bigquery_tools.AppendDestinationsFn(self.table_reference), *self.table_side_inputs) - | 'AddInsertIdsWithRandomKeys' >> beam.ParDo( - _StreamToBigQuery.InsertIdPrefixFn())) - - sharded_data = (sharded_data | 'CommitInsertIds' >> ReshufflePerKey()) + | 'AddInsertIds' >> beam.ParDo(_StreamToBigQuery.InsertIdPrefixFn())) + + if not self.with_auto_sharding: + tagged_data = ( + tagged_data + | 'WithFixedSharding' >> beam.Map(_add_random_shard) + | 'CommitInsertIds' >> ReshufflePerKey() + | 'DropShard' >> beam.Map(lambda kv: (kv[0][0], kv[1]))) + else: + # Auto-sharding is achieved via GroupIntoBatches.WithShardedKey + # transform which shards, groups and at the same time batches the table + # rows to be inserted to BigQuery. + + # Firstly the keys of tagged_data (table references) are converted to a + # hashable format. This is needed to work with the keyed states used by + # GroupIntoBatches. After grouping and batching is done, original table + # references are restored. 
+ tagged_data = ( + tagged_data + | 'ToHashableTableRef' >> beam.Map(_to_hashable_table_ref) + | 'WithAutoSharding' >> beam.GroupIntoBatches.WithShardedKey( + (self.batch_size or BigQueryWriteFn.DEFAULT_MAX_BUFFERED_ROWS), + DEFAULT_BATCH_BUFFERING_DURATION_LIMIT_SEC) + | 'FromHashableTableRefAndDropShard' >> beam.Map(_restore_table_ref)) return ( - sharded_data - | 'DropShard' >> beam.Map(drop_shard) + tagged_data | 'StreamInsertRows' >> ParDo( bigquery_write_fn, *self.schema_side_inputs).with_outputs( BigQueryWriteFn.FAILED_ROWS, main='main')) @@ -1467,7 +1509,9 @@ def __init__( triggering_frequency=None, validate=True, temp_file_format=None, - ignore_insert_ids=False): + ignore_insert_ids=False, + # TODO(BEAM-11857): Switch the default when the feature is mature. + with_auto_sharding=False): """Initialize a WriteToBigQuery transform. Args: @@ -1524,7 +1568,6 @@ def __init__( tables. batch_size (int): Number of rows to be written to BQ per streaming API insert. The default is 500. - insert. test_client: Override the default bigquery client used for testing. max_file_size (int): The maximum size for a file to be written and then loaded into BigQuery. The default value is 4TB, which is 80% of the @@ -1591,6 +1634,10 @@ def __init__( duplication of data inserted to BigQuery, set `ignore_insert_ids` to True to increase the throughput for BQ writing. See: https://cloud.google.com/bigquery/streaming-data-into-bigquery#disabling_best_effort_de-duplication + with_auto_sharding: Experimental. If true, enables using a dynamically + determined number of shards to write to BigQuery. This can be used for + both FILE_LOADS and STREAMING_INSERTS. Only applicable to unbounded + input. """ self._table = table self._dataset = dataset @@ -1615,6 +1662,7 @@ def __init__( self.max_files_per_bundle = max_files_per_bundle self.method = method or WriteToBigQuery.Method.DEFAULT self.triggering_frequency = triggering_frequency + self.with_auto_sharding = with_auto_sharding self.insert_retry_strategy = insert_retry_strategy self._validate = validate self._temp_file_format = temp_file_format or bigquery_tools.FileFormat.JSON @@ -1649,10 +1697,14 @@ def expand(self, pcoll): self.table_reference.projectId = pcoll.pipeline.options.view_as( GoogleCloudOptions).project - experiments = p.options.view_as(DebugOptions).experiments or [] # TODO(pabloem): Use a different method to determine if streaming or batch. 
is_streaming_pipeline = p.options.view_as(StandardOptions).streaming + if not is_streaming_pipeline and self.with_auto_sharding: + raise ValueError( + 'with_auto_sharding is not applicable to batch pipelines.') + + experiments = p.options.view_as(DebugOptions).experiments or [] method_to_use = self._compute_method(experiments, is_streaming_pipeline) if method_to_use == WriteToBigQuery.Method.STREAMING_INSERTS: @@ -1667,17 +1719,18 @@ def expand(self, pcoll): 'FILE_LOADS method of writing to BigQuery.') outputs = pcoll | _StreamToBigQuery( - self.table_reference, - self.table_side_inputs, - self.schema_side_inputs, - self.schema, - self.batch_size, - self.create_disposition, - self.write_disposition, - self.kms_key, - self.insert_retry_strategy, - self.additional_bq_parameters, - self._ignore_insert_ids, + table_reference=self.table_reference, + table_side_inputs=self.table_side_inputs, + schema_side_inputs=self.schema_side_inputs, + schema=self.schema, + batch_size=self.batch_size, + create_disposition=self.create_disposition, + write_disposition=self.write_disposition, + kms_key=self.kms_key, + retry_strategy=self.insert_retry_strategy, + additional_bq_parameters=self.additional_bq_parameters, + ignore_insert_ids=self._ignore_insert_ids, + with_auto_sharding=self.with_auto_sharding, test_client=self.test_client) return {BigQueryWriteFn.FAILED_ROWS: outputs[BigQueryWriteFn.FAILED_ROWS]} @@ -1701,6 +1754,7 @@ def expand(self, pcoll): create_disposition=self.create_disposition, write_disposition=self.write_disposition, triggering_frequency=self.triggering_frequency, + with_auto_sharding=self.with_auto_sharding, temp_file_format=self._temp_file_format, max_file_size=self.max_file_size, max_files_per_bundle=self.max_files_per_bundle, @@ -1759,6 +1813,8 @@ def serialize(side_inputs): 'triggering_frequency': self.triggering_frequency, 'validate': self._validate, 'temp_file_format': self._temp_file_format, + 'ignore_insert_ids': self._ignore_insert_ids, + 'with_auto_sharding': self.with_auto_sharding, } return 'beam:transform:write_to_big_query:v0', pickler.dumps(config) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py index ae5ebcc03205..b816efc3ec39 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py @@ -33,6 +33,7 @@ import hashlib import logging import random +import time import uuid from future.utils import iteritems @@ -46,6 +47,7 @@ from apache_beam.options.pipeline_options import GoogleCloudOptions from apache_beam.transforms import trigger from apache_beam.transforms.display import DisplayDataItem +from apache_beam.transforms.util import GroupIntoBatches from apache_beam.transforms.window import GlobalWindows _LOGGER = logging.getLogger(__name__) @@ -641,6 +643,7 @@ def __init__( create_disposition=None, write_disposition=None, triggering_frequency=None, + with_auto_sharding=False, temp_file_format=None, max_file_size=None, max_files_per_bundle=None, @@ -656,6 +659,7 @@ def __init__( self.create_disposition = create_disposition self.write_disposition = write_disposition self.triggering_frequency = triggering_frequency + self.with_auto_sharding = with_auto_sharding self.max_file_size = max_file_size or _DEFAULT_MAX_FILE_SIZE self.max_files_per_bundle = ( max_files_per_bundle or _DEFAULT_MAX_WRITERS_PER_BUNDLE) @@ -708,6 +712,9 @@ def verify(self): raise ValueError( 'triggering_frequency can only be used with file' 'loads in streaming') + if 
not self.is_streaming_pipeline and self.with_auto_sharding:
+      raise ValueError(
+          'with_auto_sharding can only be used with file loads in streaming.')

   def _window_fn(self):
     """Set the correct WindowInto PTransform"""
@@ -719,7 +726,11 @@ def _window_fn(self):
     # that the files are written if a threshold number of records are ready.
     # We use only the user-supplied trigger on the actual BigQuery load.
     # This allows us to offload the data to the filesystem.
-    if self.is_streaming_pipeline:
+    #
+    # In the case of auto sharding, however, we use the default trigger and
+    # instead apply the user-supplied triggering_frequency to the transform
+    # that performs sharding.
+    if self.is_streaming_pipeline and not self.with_auto_sharding:
       return beam.WindowInto(beam.window.GlobalWindows(),
                              trigger=trigger.Repeatedly(
                                  trigger.AfterAny(
@@ -732,6 +743,21 @@ def _window_fn(self):
     else:
       return beam.WindowInto(beam.window.GlobalWindows())

+  def _maybe_apply_user_trigger(self, destination_file_kv_pc):
+    if self.is_streaming_pipeline:
+      # Apply the user's trigger back before we start triggering load jobs
+      return (
+          destination_file_kv_pc
+          | "ApplyUserTrigger" >> beam.WindowInto(
+              beam.window.GlobalWindows(),
+              trigger=trigger.Repeatedly(
+                  trigger.AfterAll(
+                      trigger.AfterProcessingTime(self.triggering_frequency),
+                      trigger.AfterCount(1))),
+              accumulation_mode=trigger.AccumulationMode.DISCARDING))
+    else:
+      return destination_file_kv_pc
+
   def _write_files(self, destination_data_kv_pc, file_prefix_pcv):
     outputs = (
         destination_data_kv_pc
@@ -774,19 +800,38 @@ def _write_files(self, destination_data_kv_pc, file_prefix_pcv):
         (destination_files_kv_pc, more_destination_files_kv_pc)
         | "DestinationFilesUnion" >> beam.Flatten()
         | "IdentityWorkaround" >> beam.Map(lambda x: x))
+    return self._maybe_apply_user_trigger(all_destination_file_pairs_pc)

-    if self.is_streaming_pipeline:
-      # Apply the user's trigger back before we start triggering load jobs
-      all_destination_file_pairs_pc = (
-          all_destination_file_pairs_pc
-          | "ApplyUserTrigger" >> beam.WindowInto(
-              beam.window.GlobalWindows(),
-              trigger=trigger.Repeatedly(
-                  trigger.AfterAll(
-                      trigger.AfterProcessingTime(self.triggering_frequency),
-                      trigger.AfterCount(1))),
-              accumulation_mode=trigger.AccumulationMode.DISCARDING))
-    return all_destination_file_pairs_pc
+  def _write_files_with_auto_sharding(
+      self, destination_data_kv_pc, file_prefix_pcv):
+    clock = self.test_client.test_clock if self.test_client else time.time
+
+    # Auto-sharding is achieved via the GroupIntoBatches.WithShardedKey
+    # transform, which shards, groups and at the same time batches the table
+    # rows to be inserted to BigQuery.
+
+    # Firstly, the keys of tagged_data (table references) are converted to a
+    # hashable format. This is needed to work with the keyed states used by
+    # GroupIntoBatches. After grouping and batching is done, table references
+    # are restored.
+ destination_files_kv_pc = ( + destination_data_kv_pc + | 'ToHashableTableRef' >> beam.Map( + lambda kv: (bigquery_tools.get_hashable_destination(kv[0]), kv[1])) + | 'WithAutoSharding' >> GroupIntoBatches.WithShardedKey( + batch_size=_FILE_TRIGGERING_RECORD_COUNT, + max_buffering_duration_secs=self.triggering_frequency, + clock=clock) + | 'FromHashableTableRefAndDropShard' >> beam.Map( + lambda kvs: + (bigquery_tools.parse_table_reference(kvs[0].key), kvs[1])) + | beam.ParDo( + WriteGroupedRecordsToFile( + schema=self.schema, file_format=self._temp_file_format), + file_prefix_pcv, + *self.schema_side_inputs)) + + return self._maybe_apply_user_trigger(destination_files_kv_pc) def _load_data( self, @@ -933,8 +978,12 @@ def expand(self, pcoll): bigquery_tools.AppendDestinationsFn(self.destination), *self.table_side_inputs)) - all_destination_file_pairs_pc = self._write_files( - destination_data_kv_pc, file_prefix_pcv) + if not self.with_auto_sharding: + all_destination_file_pairs_pc = self._write_files( + destination_data_kv_pc, file_prefix_pcv) + else: + all_destination_file_pairs_pc = self._write_files_with_auto_sharding( + destination_data_kv_pc, file_prefix_pcv) grouped_files_pc = ( all_destination_file_pairs_pc diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py index fca7d9c94bd2..0de211f7c16b 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py @@ -44,6 +44,7 @@ from apache_beam.io.gcp.internal.clients import bigquery as bigquery_api from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher +from apache_beam.options.pipeline_options import StandardOptions from apache_beam.runners.dataflow.test_dataflow_runner import TestDataflowRunner from apache_beam.runners.runner import PipelineState from apache_beam.testing.pipeline_verifiers import PipelineStateMatcher @@ -52,7 +53,9 @@ from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms import combiners +from apache_beam.transforms.window import TimestampedValue from apache_beam.typehints.typehints import Tuple +from apache_beam.utils import timestamp try: from apitools.base.py.exceptions import HttpError @@ -603,6 +606,113 @@ def test_multiple_partition_files(self): equal_to([6]), label='CheckCopyJobCount') + @parameterized.expand([ + param(is_streaming=False, with_auto_sharding=False), + param(is_streaming=True, with_auto_sharding=False), + param(is_streaming=True, with_auto_sharding=True), + ]) + def test_triggering_frequency(self, is_streaming, with_auto_sharding): + destination = 'project1:dataset1.table1' + + job_reference = bigquery_api.JobReference() + job_reference.projectId = 'project1' + job_reference.jobId = 'job_name1' + result_job = bigquery_api.Job() + result_job.jobReference = job_reference + + mock_job = mock.Mock() + mock_job.status.state = 'DONE' + mock_job.status.errorResult = None + mock_job.jobReference = job_reference + + bq_client = mock.Mock() + bq_client.jobs.Get.return_value = mock_job + bq_client.jobs.Insert.return_value = result_job + + # Insert a fake clock to work with auto-sharding which needs a processing + # time timer. 
+    class _FakeClock(object):
+      def __init__(self, now=time.time()):
+        self._now = now
+
+      def __call__(self):
+        return self._now
+
+    start_time = timestamp.Timestamp(0)
+    bq_client.test_clock = _FakeClock(now=start_time)
+
+    triggering_frequency = 20 if is_streaming else None
+    transform = bqfl.BigQueryBatchFileLoads(
+        destination,
+        custom_gcs_temp_location=self._new_tempdir(),
+        test_client=bq_client,
+        validate=False,
+        temp_file_format=bigquery_tools.FileFormat.JSON,
+        is_streaming_pipeline=is_streaming,
+        triggering_frequency=triggering_frequency,
+        with_auto_sharding=with_auto_sharding)
+
+    # Need to test this with the DirectRunner to avoid serializing mocks
+    with TestPipeline(runner='BundleBasedDirectRunner',
+                      options=StandardOptions(streaming=is_streaming)) as p:
+      if is_streaming:
+        _SIZE = len(_ELEMENTS)
+        first_batch = [
+            TimestampedValue(value, start_time + i + 1) for i,
+            value in enumerate(_ELEMENTS[:_SIZE // 2])
+        ]
+        second_batch = [
+            TimestampedValue(value, start_time + _SIZE // 2 + i + 1) for i,
+            value in enumerate(_ELEMENTS[_SIZE // 2:])
+        ]
+        # Advance processing time between batches of input elements to fire the
+        # user triggers. Intentionally advance the processing time twice for the
+        # auto-sharding case since we need to first fire the timer and then
+        # fire the trigger.
+        test_stream = (
+            TestStream().advance_watermark_to(start_time).add_elements(
+                first_batch).advance_processing_time(
+                    30).advance_processing_time(30).add_elements(second_batch).
+            advance_processing_time(30).advance_processing_time(
+                30).advance_watermark_to_infinity())
+        input = p | test_stream
+      else:
+        input = p | beam.Create(_ELEMENTS)
+      outputs = input | transform
+
+      dest_files = outputs[bqfl.BigQueryBatchFileLoads.DESTINATION_FILE_PAIRS]
+      dest_job = outputs[bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS]
+
+      files = dest_files | "GetFiles" >> beam.Map(lambda x: x[1][0])
+      destinations = (
+          dest_files
+          | "GetDests" >> beam.Map(
+              lambda x: (bigquery_tools.get_hashable_destination(x[0]), x[1]))
+          | "GetUniques" >> combiners.Count.PerKey()
+          | "GetFinalDests" >> beam.Keys())
+      jobs = dest_job | "GetJobs" >> beam.Map(lambda x: x[1])
+
+      # Check that all files exist.
+      _ = (
+          files
+          | beam.Map(lambda x: hamcrest_assert(os.path.exists(x), is_(True))))
+
+      # Expect two load jobs to be generated in the streaming case due to the
+      # triggering frequency. Grouping is per trigger so we expect two entries
+      # in the output as opposed to one.
+ file_count = files | combiners.Count.Globally().without_defaults() + expected_file_count = [1, 1] if is_streaming else [1] + expected_destinations = [destination, destination + ] if is_streaming else [destination] + expected_jobs = [job_reference, job_reference + ] if is_streaming else [job_reference] + assert_that(file_count, equal_to(expected_file_count), label='CountFiles') + assert_that( + destinations, + equal_to(expected_destinations), + label='CheckDestinations') + assert_that(jobs, equal_to(expected_jobs), label='CheckJobs') + class BigQueryFileLoadsIT(unittest.TestCase): diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py index fefe461b0467..674e1f8571f7 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py @@ -38,6 +38,8 @@ import mock import pytz from nose.plugins.attrib import attr +from parameterized import param +from parameterized import parameterized import apache_beam as beam from apache_beam.internal import pickler @@ -892,6 +894,41 @@ def test_dofn_client_no_records(self): # InsertRows not called in finish bundle as no records self.assertFalse(client.tabledata.InsertAll.called) + def test_with_batched_input(self): + client = mock.Mock() + client.tables.Get.return_value = bigquery.Table( + tableReference=bigquery.TableReference( + projectId='project_id', datasetId='dataset_id', tableId='table_id')) + client.tabledata.InsertAll.return_value = \ + bigquery.TableDataInsertAllResponse(insertErrors=[]) + create_disposition = beam.io.BigQueryDisposition.CREATE_IF_NEEDED + write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND + + fn = beam.io.gcp.bigquery.BigQueryWriteFn( + batch_size=10, + create_disposition=create_disposition, + write_disposition=write_disposition, + kms_key=None, + with_batched_input=True, + test_client=client) + + fn.start_bundle() + + # Destination is a tuple of (destination, schema) to ensure the table is + # created. + fn.process(( + 'project_id:dataset_id.table_id', + [({ + 'month': 1 + }, 'insertid3'), ({ + 'month': 2 + }, 'insertid2'), ({ + 'month': 3 + }, 'insertid1')])) + + # InsertRows called since the input is already batched. 
+ self.assertTrue(client.tabledata.InsertAll.called) + @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class PipelineBasedStreamingInsertTest(_TestCaseWithTempDirCleanUp): @@ -937,19 +974,89 @@ def store_callback(arg): 'columnA': 'value5', 'columnB': 'value6' }]) | _StreamToBigQuery( - 'project:dataset.table', [], [], - 'anyschema', - None, - 'CREATE_NEVER', - None, - None, - None, [], + table_reference='project:dataset.table', + table_side_inputs=[], + schema_side_inputs=[], + schema='anyschema', + batch_size=None, + create_disposition='CREATE_NEVER', + write_disposition=None, + kms_key=None, + retry_strategy=None, + additional_bq_parameters=[], ignore_insert_ids=False, + with_auto_sharding=False, test_client=client)) with open(file_name_1) as f1, open(file_name_2) as f2: self.assertEqual(json.load(f1), json.load(f2)) + @parameterized.expand([ + param(with_auto_sharding=False), + param(with_auto_sharding=True), + ]) + def test_batch_size_with_auto_sharding(self, with_auto_sharding): + tempdir = '%s%s' % (self._new_tempdir(), os.sep) + file_name_1 = os.path.join(tempdir, 'file1') + file_name_2 = os.path.join(tempdir, 'file2') + + def store_callback(arg): + insert_ids = [r.insertId for r in arg.tableDataInsertAllRequest.rows] + colA_values = [ + r.json.additionalProperties[0].value.string_value + for r in arg.tableDataInsertAllRequest.rows + ] + json_output = {'insertIds': insert_ids, 'colA_values': colA_values} + # Expect two batches of rows will be inserted. Store them separately. + if not os.path.exists(file_name_1): + with open(file_name_1, 'w') as f: + json.dump(json_output, f) + else: + with open(file_name_2, 'w') as f: + json.dump(json_output, f) + + res = mock.Mock() + res.insertErrors = [] + return res + + client = mock.Mock() + client.tabledata.InsertAll = mock.Mock(side_effect=store_callback) + + # Using the bundle based direct runner to avoid pickling problems + # with mocks. + with beam.Pipeline(runner='BundleBasedDirectRunner') as p: + _ = ( + p + | beam.Create([{ + 'columnA': 'value1', 'columnB': 'value2' + }, { + 'columnA': 'value3', 'columnB': 'value4' + }, { + 'columnA': 'value5', 'columnB': 'value6' + }]) + | _StreamToBigQuery( + table_reference='project:dataset.table', + table_side_inputs=[], + schema_side_inputs=[], + schema='anyschema', + # Set a batch size such that the input elements will be inserted + # in 2 batches. + batch_size=2, + create_disposition='CREATE_NEVER', + write_disposition=None, + kms_key=None, + retry_strategy=None, + additional_bq_parameters=[], + ignore_insert_ids=False, + with_auto_sharding=with_auto_sharding, + test_client=client)) + + with open(file_name_1) as f1, open(file_name_2) as f2: + out1 = json.load(f1) + self.assertEqual(out1['colA_values'], ['value1', 'value3']) + out2 = json.load(f2) + self.assertEqual(out2['colA_values'], ['value5']) + class BigQueryStreamingInsertTransformIntegrationTests(unittest.TestCase): BIG_QUERY_DATASET_ID = 'python_bq_streaming_inserts_' diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 0e795dbd7419..022aa699cc7c 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -811,12 +811,14 @@ class WithShardedKey(PTransform): override the default sharding to do a better load balancing during the execution time. 
""" - def __init__(self, batch_size, max_buffering_duration_secs=None): + def __init__( + self, batch_size, max_buffering_duration_secs=None, clock=time.time): """Create a new GroupIntoBatches with sharded output. See ``GroupIntoBatches`` transform for a description of input parameters. """ self.params = _GroupIntoBatchesParams( batch_size, max_buffering_duration_secs) + self.clock = clock _shard_id_prefix = uuid.uuid4().bytes @@ -836,7 +838,9 @@ def expand(self, pcoll): return ( sharded_pcoll | GroupIntoBatches( - self.params.batch_size, self.params.max_buffering_duration_secs)) + self.params.batch_size, + self.params.max_buffering_duration_secs, + self.clock)) def to_runner_api_parameter( self,