138 changes: 97 additions & 41 deletions sdks/python/apache_beam/io/gcp/bigquery.py
@@ -362,8 +362,6 @@ def compute_table_name(row):
NOTE: This job name template does not have backwards compatibility guarantees.
"""
BQ_JOB_NAME_TEMPLATE = "beam_bq_job_{job_type}_{job_id}_{step_id}{random}"
"""The number of shards per destination when writing via streaming inserts."""
DEFAULT_SHARDS_PER_DESTINATION = 500


@deprecated(since='2.11.0', current="bigquery_tools.parse_table_reference")
@@ -1081,7 +1079,8 @@ def __init__(
max_buffered_rows=None,
retry_strategy=None,
additional_bq_parameters=None,
ignore_insert_ids=False):
ignore_insert_ids=False,
with_batched_input=False):
"""Initialize a WriteToBigQuery transform.

Args:
@@ -1122,6 +1121,9 @@ def __init__(
duplication of data inserted to BigQuery, set `ignore_insert_ids`
to True to increase the throughput for BQ writing. See:
https://cloud.google.com/bigquery/streaming-data-into-bigquery#disabling_best_effort_de-duplication
with_batched_input: Whether the input has already been batched per
destination. If not, perform best-effort batching per destination within
a bundle.
"""
self.schema = schema
self.test_client = test_client
@@ -1142,6 +1144,7 @@ def __init__(
max_buffered_rows or BigQueryWriteFn.DEFAULT_MAX_BUFFERED_ROWS)
self._retry_strategy = retry_strategy or RetryStrategy.RETRY_ALWAYS
self.ignore_insert_ids = ignore_insert_ids
self.with_batched_input = with_batched_input

self.additional_bq_parameters = additional_bq_parameters or {}

Expand Down Expand Up @@ -1254,13 +1257,19 @@ def process(self, element, *schema_side_inputs):

destination = bigquery_tools.get_hashable_destination(destination)

row_and_insert_id = element[1]
self._rows_buffer[destination].append(row_and_insert_id)
self._total_buffered_rows += 1
if len(self._rows_buffer[destination]) >= self._max_batch_size:
if not self.with_batched_input:
row_and_insert_id = element[1]
self._rows_buffer[destination].append(row_and_insert_id)
self._total_buffered_rows += 1
if len(self._rows_buffer[destination]) >= self._max_batch_size:
return self._flush_batch(destination)
elif self._total_buffered_rows >= self._max_buffered_rows:
return self._flush_all_batches()
else:
# The input is already batched per destination; flush the rows now.
batched_rows = element[1]
self._rows_buffer[destination].extend(batched_rows)
return self._flush_batch(destination)
elif self._total_buffered_rows >= self._max_buffered_rows:
return self._flush_all_batches()
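
For reference, the two element shapes this DoFn now accepts look roughly like this (the table and row values below are made-up placeholders, not values from this PR):

# Illustrative only: element shapes seen by BigQueryWriteFn.process().
# with_batched_input=False -- one (row, insert_id) pair per element, buffered
# until a per-destination or total-row threshold is reached:
unbatched_element = (
    'my-project:my_dataset.my_table', ({'name': 'a'}, 'prefix-0'))
# with_batched_input=True -- all rows for a destination arrive pre-grouped
# (e.g. from GroupIntoBatches.WithShardedKey) and are flushed immediately:
batched_element = (
    'my-project:my_dataset.my_table',
    [({'name': 'a'}, 'prefix-0'), ({'name': 'b'}, 'prefix-1')])
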

def finish_bundle(self):
bigquery_tools.BigQueryWrapper.HISTOGRAM_METRIC_LOGGER.log_metrics(
Expand Down Expand Up @@ -1348,6 +1357,13 @@ def _flush_batch(self, destination):
]


# The number of shards per destination when writing via streaming inserts.
DEFAULT_SHARDS_PER_DESTINATION = 500
# The max duration a batch of elements is allowed to be buffered before being
# flushed to BigQuery.
DEFAULT_BATCH_BUFFERING_DURATION_LIMIT_SEC = 0.2


class _StreamToBigQuery(PTransform):
def __init__(
self,
@@ -1362,6 +1378,7 @@ def __init__(
retry_strategy,
additional_bq_parameters,
ignore_insert_ids,
with_auto_sharding,
test_client=None):
self.table_reference = table_reference
self.table_side_inputs = table_side_inputs
@@ -1375,20 +1392,16 @@ def __init__(
self.test_client = test_client
self.additional_bq_parameters = additional_bq_parameters
self.ignore_insert_ids = ignore_insert_ids
self.with_auto_sharding = with_auto_sharding

class InsertIdPrefixFn(DoFn):
def __init__(self, shards=DEFAULT_SHARDS_PER_DESTINATION):
self.shards = shards

def start_bundle(self):
self.prefix = str(uuid.uuid4())
self._row_count = 0

def process(self, element):
key = element[0]
value = element[1]
key = (key, random.randint(0, self.shards))

insert_id = '%s-%s' % (self.prefix, self._row_count)
self._row_count += 1
yield (key, (value, insert_id))
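
A standalone sketch of what InsertIdPrefixFn now emits after this change; sharding is no longer done here, and the destination strings and rows below are illustrative only:

import uuid

# Per bundle, a fresh UUID prefix is generated and each row gets an insert id
# of the form '<uuid4>-<row_count>'. The destination key passes through
# unchanged; the random shard key is now added in a separate step.
prefix = str(uuid.uuid4())
rows = [('my-project:my_dataset.my_table', {'name': 'a'}),
        ('my-project:my_dataset.my_table', {'name': 'b'})]
tagged = [(key, (value, '%s-%s' % (prefix, i)))
          for i, (key, value) in enumerate(rows)]
# tagged[0] == ('my-project:my_dataset.my_table', ({'name': 'a'}, '<uuid>-0'))
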
@@ -1403,27 +1416,56 @@ def expand(self, input):
retry_strategy=self.retry_strategy,
test_client=self.test_client,
additional_bq_parameters=self.additional_bq_parameters,
ignore_insert_ids=self.ignore_insert_ids)
ignore_insert_ids=self.ignore_insert_ids,
with_batched_input=self.with_auto_sharding)

def _add_random_shard(element):
key = element[0]
value = element[1]
return ((key, random.randint(0, DEFAULT_SHARDS_PER_DESTINATION)), value)

def _to_hashable_table_ref(table_ref_elem_kv):
table_ref = table_ref_elem_kv[0]
hashable_table_ref = bigquery_tools.get_hashable_destination(table_ref)
return (hashable_table_ref, table_ref_elem_kv[1])

def drop_shard(elms):
key_and_shard = elms[0]
key = key_and_shard[0]
value = elms[1]
return (key, value)
def _restore_table_ref(sharded_table_ref_elems_kv):
sharded_table_ref = sharded_table_ref_elems_kv[0]
table_ref = bigquery_tools.parse_table_reference(sharded_table_ref.key)
return (table_ref, sharded_table_ref_elems_kv[1])

sharded_data = (
tagged_data = (
input
| 'AppendDestination' >> beam.ParDo(
bigquery_tools.AppendDestinationsFn(self.table_reference),
*self.table_side_inputs)
| 'AddInsertIdsWithRandomKeys' >> beam.ParDo(
_StreamToBigQuery.InsertIdPrefixFn()))

sharded_data = (sharded_data | 'CommitInsertIds' >> ReshufflePerKey())
| 'AddInsertIds' >> beam.ParDo(_StreamToBigQuery.InsertIdPrefixFn()))

if not self.with_auto_sharding:
tagged_data = (
tagged_data
| 'WithFixedSharding' >> beam.Map(_add_random_shard)
| 'CommitInsertIds' >> ReshufflePerKey()
| 'DropShard' >> beam.Map(lambda kv: (kv[0][0], kv[1])))
else:
# Auto-sharding is achieved via the GroupIntoBatches.WithShardedKey
# transform, which shards, groups, and batches the table rows to be
# inserted into BigQuery at the same time.

# First, the keys of tagged_data (table references) are converted to a
# hashable format. This is needed to work with the keyed states used by
# GroupIntoBatches. After grouping and batching is done, the original
# table references are restored.
tagged_data = (
tagged_data
| 'ToHashableTableRef' >> beam.Map(_to_hashable_table_ref)
| 'WithAutoSharding' >> beam.GroupIntoBatches.WithShardedKey(
(self.batch_size or BigQueryWriteFn.DEFAULT_MAX_BUFFERED_ROWS),
DEFAULT_BATCH_BUFFERING_DURATION_LIMIT_SEC)
| 'FromHashableTableRefAndDropShard' >> beam.Map(_restore_table_ref))
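
A minimal, self-contained sketch of the auto-sharding grouping step, assuming a small in-memory input with string destinations (whether the 0.2 s buffering limit has any visible effect depends on the runner):

import apache_beam as beam

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([
          ('my-project:my_dataset.my_table', ({'name': 'a'}, 'id-0')),
          ('my-project:my_dataset.my_table', ({'name': 'b'}, 'id-1'))])
      | beam.GroupIntoBatches.WithShardedKey(
          batch_size=500, max_buffering_duration_secs=0.2)
      # Each output element is (ShardedKey(table), [(row, insert_id), ...]);
      # the transform in the diff additionally restores the TableReference key.
      | beam.MapTuple(lambda sharded_key, batch: (sharded_key.key, batch))
      | beam.Map(print))
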

return (
sharded_data
| 'DropShard' >> beam.Map(drop_shard)
tagged_data
| 'StreamInsertRows' >> ParDo(
bigquery_write_fn, *self.schema_side_inputs).with_outputs(
BigQueryWriteFn.FAILED_ROWS, main='main'))
Expand Down Expand Up @@ -1467,7 +1509,9 @@ def __init__(
triggering_frequency=None,
validate=True,
temp_file_format=None,
ignore_insert_ids=False):
ignore_insert_ids=False,
# TODO(BEAM-11857): Switch the default when the feature is mature.
with_auto_sharding=False):
"""Initialize a WriteToBigQuery transform.

Args:
Expand Down Expand Up @@ -1524,7 +1568,6 @@ def __init__(
tables.
batch_size (int): Number of rows to be written to BQ per streaming API
insert. The default is 500.
insert.
test_client: Override the default bigquery client used for testing.
max_file_size (int): The maximum size for a file to be written and then
loaded into BigQuery. The default value is 4TB, which is 80% of the
Expand Down Expand Up @@ -1591,6 +1634,10 @@ def __init__(
duplication of data inserted to BigQuery, set `ignore_insert_ids`
to True to increase the throughput for BQ writing. See:
https://cloud.google.com/bigquery/streaming-data-into-bigquery#disabling_best_effort_de-duplication
with_auto_sharding: Experimental. If true, enables using a dynamically
determined number of shards to write to BigQuery. This can be used for
both FILE_LOADS and STREAMING_INSERTS. Only applicable to unbounded
input.
"""
self._table = table
self._dataset = dataset
@@ -1615,6 +1662,7 @@ def __init__(
self.max_files_per_bundle = max_files_per_bundle
self.method = method or WriteToBigQuery.Method.DEFAULT
self.triggering_frequency = triggering_frequency
self.with_auto_sharding = with_auto_sharding
self.insert_retry_strategy = insert_retry_strategy
self._validate = validate
self._temp_file_format = temp_file_format or bigquery_tools.FileFormat.JSON
Expand Down Expand Up @@ -1649,10 +1697,14 @@ def expand(self, pcoll):
self.table_reference.projectId = pcoll.pipeline.options.view_as(
GoogleCloudOptions).project

experiments = p.options.view_as(DebugOptions).experiments or []
# TODO(pabloem): Use a different method to determine if streaming or batch.
is_streaming_pipeline = p.options.view_as(StandardOptions).streaming

if not is_streaming_pipeline and self.with_auto_sharding:
raise ValueError(
'with_auto_sharding is not applicable to batch pipelines.')

experiments = p.options.view_as(DebugOptions).experiments or []
method_to_use = self._compute_method(experiments, is_streaming_pipeline)

if method_to_use == WriteToBigQuery.Method.STREAMING_INSERTS:
@@ -1667,17 +1719,18 @@ def expand(self, pcoll):
'FILE_LOADS method of writing to BigQuery.')

outputs = pcoll | _StreamToBigQuery(
self.table_reference,
self.table_side_inputs,
self.schema_side_inputs,
self.schema,
self.batch_size,
self.create_disposition,
self.write_disposition,
self.kms_key,
self.insert_retry_strategy,
self.additional_bq_parameters,
self._ignore_insert_ids,
table_reference=self.table_reference,
table_side_inputs=self.table_side_inputs,
schema_side_inputs=self.schema_side_inputs,
schema=self.schema,
batch_size=self.batch_size,
create_disposition=self.create_disposition,
write_disposition=self.write_disposition,
kms_key=self.kms_key,
retry_strategy=self.insert_retry_strategy,
additional_bq_parameters=self.additional_bq_parameters,
ignore_insert_ids=self._ignore_insert_ids,
with_auto_sharding=self.with_auto_sharding,
test_client=self.test_client)

return {BigQueryWriteFn.FAILED_ROWS: outputs[BigQueryWriteFn.FAILED_ROWS]}
Expand All @@ -1701,6 +1754,7 @@ def expand(self, pcoll):
create_disposition=self.create_disposition,
write_disposition=self.write_disposition,
triggering_frequency=self.triggering_frequency,
with_auto_sharding=self.with_auto_sharding,
temp_file_format=self._temp_file_format,
max_file_size=self.max_file_size,
max_files_per_bundle=self.max_files_per_bundle,
Expand Down Expand Up @@ -1759,6 +1813,8 @@ def serialize(side_inputs):
'triggering_frequency': self.triggering_frequency,
'validate': self._validate,
'temp_file_format': self._temp_file_format,
'ignore_insert_ids': self._ignore_insert_ids,
Contributor Author: ignore_insert_ids was previously not added to the runner API parameters. Wondering if this is a bug, since the parameter will be lost when we convert back from the runner API.

Member: ah thanks Siyuan!

'with_auto_sharding': self.with_auto_sharding,
}
return 'beam:transform:write_to_big_query:v0', pickler.dumps(config)
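
To illustrate the review comment above: any constructor argument omitted from this config dict is silently dropped when the transform is rebuilt from the pickled payload. A rough sketch of that round trip, with the deserialization side assumed rather than quoted from this file:

from apache_beam.internal import pickler

config = {'table': 'my-project:my_dataset.my_table',
          'ignore_insert_ids': True,
          'with_auto_sharding': True}
payload = pickler.dumps(config)           # serialization, as in the diff above
restored_kwargs = pickler.loads(payload)  # what the runner-API hook unpickles
# If 'ignore_insert_ids' were missing from config, a rebuilt
# WriteToBigQuery(**restored_kwargs) would fall back to the default (False).
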
