diff --git a/CHANGES.md b/CHANGES.md
index 294ab4e3df82..a380c772048e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -25,6 +25,8 @@
 * New highly anticipated feature Y added to JavaSDK ([BEAM-Y](https://issues.apache.org/jira/browse/BEAM-Y)).
 
 ### I/Os
+
+* Support for Google Cloud Spanner added to the Python SDK. This is an experimental module for reading and writing data from Google Cloud Spanner ([BEAM-7246](https://issues.apache.org/jira/browse/BEAM-7246)).
 * Python SDK: Adds support for standard HDFS URLs (with server name). ([#10223](https://github.com/apache/beam/pull/10223)).
 * Support for X source added (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
index a320e6ea795d..b575f6ee4d74 100644
--- a/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
+++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
@@ -22,6 +22,8 @@
 This is an experimental module for reading and writing data from Google Cloud
 Spanner. Visit: https://cloud.google.com/spanner for more details.
 
+Reading Data from Cloud Spanner.
+
 To read from Cloud Spanner apply ReadFromSpanner transformation. It will
 return a PCollection, where each element represents an individual row returned
 from the read operation. Both Query and Read APIs are supported.
@@ -109,20 +111,81 @@
 ReadFromSpanner takes this transform in the constructor and pass this to the
 read pipeline as the singleton side input.
+
+
+Writing Data to Cloud Spanner.
+
+The WriteToSpanner transform writes to Cloud Spanner by executing a
+collection of input rows (WriteMutation). The mutations are grouped into
+batches for efficiency.
+
+The WriteToSpanner transform relies on the WriteMutation objects exposed by
+the SpannerIO API. WriteMutation has five static methods (insert, update,
+insert_or_update, replace, delete). Each of these methods returns an instance
+of the _Mutator object, which contains the mutation type and the Spanner
+Mutation object. For more details, review the docs of the class
+SpannerIO.WriteMutation. For example::
+
+  mutations = [
+      WriteMutation.insert(table='user', columns=('name', 'email'),
+                           values=[('sara', 'sara@dev.com')])
+  ]
+  _ = (p
+       | beam.Create(mutations)
+       | WriteToSpanner(
+           project_id=SPANNER_PROJECT_ID,
+           instance_id=SPANNER_INSTANCE_ID,
+           database_id=SPANNER_DATABASE_NAME)
+       )
+
+You can also create a WriteMutation by calling its constructor. For example::
+
+  mutations = [
+      WriteMutation(insert='users', columns=('name', 'email'),
+                    values=[('sara', 'sara@example.com')])
+  ]
+
+For more information, review the docs of the WriteMutation class.
+
+The WriteToSpanner transform also takes three batching parameters
+(max_number_rows, max_number_cells and max_batch_size_bytes). By default,
+max_number_rows is set to 50 rows, max_number_cells is set to 500 cells and
+max_batch_size_bytes is set to 1 MB (1048576 bytes). These parameters are
+used to reduce the number of transactions sent to Spanner by grouping the
+mutations into batches. Set these parameters to smaller values, or to zero,
+to reduce or disable batching. Unlike the Java connector, this connector does
+not create batches of transactions sorted by table and primary key.
+
+The WriteToSpanner transform starts by grouping the mutations into batches.
+The first step in this process is to build mutation groups from the
+WriteMutation objects and then split them into batchable and unbatchable
+mutation groups.
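+
+Related mutations can also be bundled into a MutationGroup; the whole group
+is then treated as a single unit when batching. A minimal sketch (the table
+and column names here are only illustrative)::
+
+  mutation_groups = [
+      MutationGroup([
+          WriteMutation.insert(table='user', columns=('name', 'email'),
+                               values=[('sara', 'sara@dev.com')]),
+          WriteMutation.insert(table='user', columns=('name', 'email'),
+                               values=[('dev', 'dev@dev.com')])
+      ])
+  ]
+  _ = (p
+       | beam.Create(mutation_groups)
+       | WriteToSpanner(
+           project_id=SPANNER_PROJECT_ID,
+           instance_id=SPANNER_INSTANCE_ID,
+           database_id=SPANNER_DATABASE_NAME)
+       )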
+There are three batching parameters (max_number_cells, max_number_rows and
+max_batch_size_bytes). The byte size of a mutation is calculated with the
+`google.cloud.spanner_v1.proto.mutation_pb2.Mutation.ByteSize` method. If the
+rows, cells or byte size of a mutation group exceed any of the batching
+parameters, it is tagged as an "unbatchable" mutation group. After this, all
+the batchable mutation groups are merged into batches no larger than
+"max_batch_size_bytes"; finally, the batched and unbatchable mutation groups
+are flattened together for writing. If a mutation references a table or
+column that does not exist, it will raise an exception and fail the entire
+pipeline.
 """
 
 from __future__ import absolute_import
 
 import typing
+from collections import deque
 from collections import namedtuple
 
 from apache_beam import Create
 from apache_beam import DoFn
+from apache_beam import Flatten
 from apache_beam import ParDo
 from apache_beam import Reshuffle
+from apache_beam.metrics import Metrics
 from apache_beam.pvalue import AsSingleton
 from apache_beam.pvalue import PBegin
+from apache_beam.pvalue import TaggedOutput
 from apache_beam.transforms import PTransform
 from apache_beam.transforms import ptransform_fn
+from apache_beam.transforms import window
 from apache_beam.transforms.display import DisplayDataItem
 from apache_beam.typehints import with_input_types
 from apache_beam.typehints import with_output_types
@@ -131,13 +194,22 @@
 try:
   from google.cloud.spanner import Client
   from google.cloud.spanner import KeySet
+  from google.cloud.spanner_v1 import batch
   from google.cloud.spanner_v1.database import BatchSnapshot
+  from google.cloud.spanner_v1.proto.mutation_pb2 import Mutation
 except ImportError:
   Client = None
   KeySet = None
   BatchSnapshot = None
 
-__all__ = ['create_transaction', 'ReadFromSpanner', 'ReadOperation']
+__all__ = [
+    'create_transaction',
+    'ReadFromSpanner',
+    'ReadOperation',
+    'WriteToSpanner',
+    'WriteMutation',
+    'MutationGroup'
+]
 
 
 class _SPANNER_TRANSACTION(namedtuple("SPANNER_TRANSACTION", ["transaction"])):
@@ -619,3 +691,457 @@ def display_data(self):
         str(self._transaction), label='transaction')
     return res
+
+
+@experimental(extra_message="No backwards-compatibility guarantees.")
+class WriteToSpanner(PTransform):
+  def __init__(
+      self,
+      project_id,
+      instance_id,
+      database_id,
+      pool=None,
+      credentials=None,
+      max_batch_size_bytes=1048576,
+      max_number_rows=50,
+      max_number_cells=500):
+    """
+    A PTransform to write to Google Cloud Spanner.
+
+    Args:
+      project_id: Cloud Spanner project ID. Be sure to use the Project ID,
+        not the Project Number.
+      instance_id: Cloud Spanner instance ID.
+      database_id: Cloud Spanner database ID.
+      max_batch_size_bytes: (optional) Split the mutations into batches to
+        reduce the number of transactions sent to Spanner. By default it is
+        set to 1 MB (1048576 bytes).
+      max_number_rows: (optional) Split the mutations into batches to
+        reduce the number of transactions sent to Spanner. By default it is
+        set to 50 rows per batch.
+      max_number_cells: (optional) Split the mutations into batches to
+        reduce the number of transactions sent to Spanner. By default it is
+        set to 500 cells per batch.
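+
+    Example usage (a minimal sketch; the project, instance and database IDs
+    below are placeholders)::
+
+      mutations = [
+          WriteMutation.insert(table='users', columns=('name', 'email'),
+                               values=[('sara', 'sara@example.com')])
+      ]
+      _ = (p
+           | beam.Create(mutations)
+           | WriteToSpanner(
+               project_id='your-project-id',
+               instance_id='your-instance-id',
+               database_id='your-database-id'))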
+ """ + self._configuration = _BeamSpannerConfiguration( + project=project_id, + instance=instance_id, + database=database_id, + credentials=credentials, + pool=pool, + snapshot_read_timestamp=None, + snapshot_exact_staleness=None) + self._max_batch_size_bytes = max_batch_size_bytes + self._max_number_rows = max_number_rows + self._max_number_cells = max_number_cells + self._database_id = database_id + self._project_id = project_id + self._instance_id = instance_id + self._pool = pool + + def display_data(self): + res = { + 'project_id': DisplayDataItem(self._project_id, label='Project Id'), + 'instance_id': DisplayDataItem(self._instance_id, label='Instance Id'), + 'pool': DisplayDataItem(str(self._pool), label='Pool'), + 'database': DisplayDataItem(self._database_id, label='Database'), + 'batch_size': DisplayDataItem( + self._max_batch_size_bytes, label="Batch Size"), + 'max_number_rows': DisplayDataItem( + self._max_number_rows, label="Max Rows"), + 'max_number_cells': DisplayDataItem( + self._max_number_cells, label="Max Cells"), + } + return res + + def expand(self, pcoll): + return ( + pcoll + | "make batches" >> _WriteGroup( + max_batch_size_bytes=self._max_batch_size_bytes, + max_number_rows=self._max_number_rows, + max_number_cells=self._max_number_cells) + | + 'Writing to spanner' >> ParDo(_WriteToSpannerDoFn(self._configuration))) + + +class _Mutator(namedtuple('_Mutator', + ["mutation", "operation", "kwargs", "rows", "cells"]) + ): + __slots__ = () + + @property + def byte_size(self): + return self.mutation.ByteSize() + + +class MutationGroup(deque): + """ + A Bundle of Spanner Mutations (_Mutator). + """ + @property + def info(self): + cells = 0 + rows = 0 + bytes = 0 + for m in self.__iter__(): + bytes += m.byte_size + rows += m.rows + cells += m.cells + return {"rows": rows, "cells": cells, "byte_size": bytes} + + def primary(self): + return next(self.__iter__()) + + +class WriteMutation(object): + + _OPERATION_DELETE = "delete" + _OPERATION_INSERT = "insert" + _OPERATION_INSERT_OR_UPDATE = "insert_or_update" + _OPERATION_REPLACE = "replace" + _OPERATION_UPDATE = "update" + + def __init__( + self, + insert=None, + update=None, + insert_or_update=None, + replace=None, + delete=None, + columns=None, + values=None, + keyset=None): + """ + A convenient class to create Spanner Mutations for Write. User can provide + the operation via constructor or via static methods. + + Note: If a user passing the operation via construction, make sure that it + will only accept one operation at a time. For example, if a user passing + a table name in the `insert` parameter, and he also passes the `update` + parameter value, this will cause an error. + + Args: + insert: (Optional) Name of the table in which rows will be inserted. + update: (Optional) Name of the table in which existing rows will be + updated. + insert_or_update: (Optional) Table name in which rows will be written. + Like insert, except that if the row already exists, then its column + values are overwritten with the ones provided. Any column values not + explicitly written are preserved. + replace: (Optional) Table name in which rows will be replaced. Like + insert, except that if the row already exists, it is deleted, and the + column values provided are inserted instead. Unlike `insert_or_update`, + this means any values not explicitly written become `NULL`. + delete: (Optional) Table name from which rows will be deleted. Succeeds + whether or not the named rows were present. 
+ columns: The names of the columns in table to be written. The list of + columns must contain enough columns to allow Cloud Spanner to derive + values for all primary key columns in the row(s) to be modified. + values: The values to be written. `values` can contain more than one + list of values. If it does, then multiple rows are written, one for + each entry in `values`. Each list in `values` must have exactly as + many entries as there are entries in columns above. Sending multiple + lists is equivalent to sending multiple Mutations, each containing one + `values` entry and repeating table and columns. + keyset: (Optional) The primary keys of the rows within table to delete. + Delete is idempotent. The transaction will succeed even if some or + all rows do not exist. + """ + self._columns = columns + self._values = values + self._keyset = keyset + + self._insert = insert + self._update = update + self._insert_or_update = insert_or_update + self._replace = replace + self._delete = delete + + if sum([1 for x in [self._insert, + self._update, + self._insert_or_update, + self._replace, + self._delete] if x is not None]) != 1: + raise ValueError( + "No or more than one write mutation operation " + "provided: <%s: %s>" % (self.__class__.__name__, str(self.__dict__))) + + def __call__(self, *args, **kwargs): + if self._insert is not None: + return WriteMutation.insert( + table=self._insert, columns=self._columns, values=self._values) + elif self._update is not None: + return WriteMutation.update( + table=self._update, columns=self._columns, values=self._values) + elif self._insert_or_update is not None: + return WriteMutation.insert_or_update( + table=self._insert_or_update, + columns=self._columns, + values=self._values) + elif self._replace is not None: + return WriteMutation.replace( + table=self._replace, columns=self._columns, values=self._values) + elif self._delete is not None: + return WriteMutation.delete(table=self._delete, keyset=self._keyset) + + @staticmethod + def insert(table, columns, values): + """Insert one or more new table rows. + + Args: + table: Name of the table to be modified. + columns: Name of the table columns to be modified. + values: Values to be modified. + """ + rows = len(values) + cells = len(columns) * len(values) + return _Mutator( + mutation=Mutation(insert=batch._make_write_pb(table, columns, values)), + operation=WriteMutation._OPERATION_INSERT, + rows=rows, + cells=cells, + kwargs={ + "table": table, "columns": columns, "values": values + }) + + @staticmethod + def update(table, columns, values): + """Update one or more existing table rows. + + Args: + table: Name of the table to be modified. + columns: Name of the table columns to be modified. + values: Values to be modified. + """ + rows = len(values) + cells = len(columns) * len(values) + return _Mutator( + mutation=Mutation(update=batch._make_write_pb(table, columns, values)), + operation=WriteMutation._OPERATION_UPDATE, + rows=rows, + cells=cells, + kwargs={ + "table": table, "columns": columns, "values": values + }) + + @staticmethod + def insert_or_update(table, columns, values): + """Insert/update one or more table rows. + Args: + table: Name of the table to be modified. + columns: Name of the table columns to be modified. + values: Values to be modified. 
+ """ + rows = len(values) + cells = len(columns) * len(values) + return _Mutator( + mutation=Mutation( + insert_or_update=batch._make_write_pb(table, columns, values)), + operation=WriteMutation._OPERATION_INSERT_OR_UPDATE, + rows=rows, + cells=cells, + kwargs={ + "table": table, "columns": columns, "values": values + }) + + @staticmethod + def replace(table, columns, values): + """Replace one or more table rows. + + Args: + table: Name of the table to be modified. + columns: Name of the table columns to be modified. + values: Values to be modified. + """ + rows = len(values) + cells = len(columns) * len(values) + return _Mutator( + mutation=Mutation(replace=batch._make_write_pb(table, columns, values)), + operation=WriteMutation._OPERATION_REPLACE, + rows=rows, + cells=cells, + kwargs={ + "table": table, "columns": columns, "values": values + }) + + @staticmethod + def delete(table, keyset): + """Delete one or more table rows. + + Args: + table: Name of the table to be modified. + keyset: Keys/ranges identifying rows to delete. + """ + delete = Mutation.Delete(table=table, key_set=keyset._to_pb()) + return _Mutator( + mutation=Mutation(delete=delete), + rows=0, + cells=0, + operation=WriteMutation._OPERATION_DELETE, + kwargs={ + "table": table, "keyset": keyset + }) + + +@with_input_types(typing.Union[MutationGroup, TaggedOutput]) +@with_output_types(MutationGroup) +class _BatchFn(DoFn): + """ + Batches mutations together. + """ + def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells): + self._max_batch_size_bytes = max_batch_size_bytes + self._max_number_rows = max_number_rows + self._max_number_cells = max_number_cells + + def start_bundle(self): + self._batch = MutationGroup() + self._size_in_bytes = 0 + self._rows = 0 + self._cells = 0 + + def _reset_count(self): + self._batch = MutationGroup() + self._size_in_bytes = 0 + self._rows = 0 + self._cells = 0 + + def process(self, element): + mg_info = element.info + + if mg_info['byte_size'] + self._size_in_bytes > self._max_batch_size_bytes \ + or mg_info['cells'] + self._cells > self._max_number_cells \ + or mg_info['rows'] + self._rows > self._max_number_rows: + # Batch is full, output the batch and resetting the count. + if self._batch: + yield self._batch + self._reset_count() + + self._batch.extend(element) + + # total byte size of the mutation group. + self._size_in_bytes += mg_info['byte_size'] + + # total rows in the mutation group. + self._rows += mg_info['rows'] + + # total cells in the mutation group. + self._cells += mg_info['cells'] + + def finish_bundle(self): + if self._batch is not None: + yield window.GlobalWindows.windowed_value(self._batch) + self._batch = None + + +@with_input_types(MutationGroup) +@with_output_types(MutationGroup) +class _BatchableFilterFn(DoFn): + """ + Filters MutationGroups larger than the batch size to the output tagged with + OUTPUT_TAG_UNBATCHABLE. + """ + OUTPUT_TAG_UNBATCHABLE = 'unbatchable' + + def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells): + self._max_batch_size_bytes = max_batch_size_bytes + self._max_number_rows = max_number_rows + self._max_number_cells = max_number_cells + self._batchable = None + self._unbatchable = None + + def process(self, element): + if element.primary().operation == WriteMutation._OPERATION_DELETE: + # As delete mutations are not batchable. 
+ yield TaggedOutput(_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, element) + else: + mg_info = element.info + if mg_info['byte_size'] > self._max_batch_size_bytes \ + or mg_info['cells'] > self._max_number_cells \ + or mg_info['rows'] > self._max_number_rows: + yield TaggedOutput(_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, element) + else: + yield element + + +class _WriteToSpannerDoFn(DoFn): + def __init__(self, spanner_configuration): + self._spanner_configuration = spanner_configuration + self._db_instance = None + self.batches = Metrics.counter(self.__class__, 'SpannerBatches') + + def setup(self): + spanner_client = Client(self._spanner_configuration.project) + instance = spanner_client.instance(self._spanner_configuration.instance) + self._db_instance = instance.database( + self._spanner_configuration.database, + pool=self._spanner_configuration.pool) + + def process(self, element): + self.batches.inc() + with self._db_instance.batch() as b: + for m in element: + if m.operation == WriteMutation._OPERATION_DELETE: + batch_func = b.delete + elif m.operation == WriteMutation._OPERATION_REPLACE: + batch_func = b.replace + elif m.operation == WriteMutation._OPERATION_INSERT_OR_UPDATE: + batch_func = b.insert_or_update + elif m.operation == WriteMutation._OPERATION_INSERT: + batch_func = b.insert + elif m.operation == WriteMutation._OPERATION_UPDATE: + batch_func = b.update + else: + raise ValueError("Unknown operation action: %s" % m.operation) + + batch_func(**m.kwargs) + + +@with_input_types(typing.Union[MutationGroup, _Mutator]) +@with_output_types(MutationGroup) +class _MakeMutationGroupsFn(DoFn): + """ + Make Mutation group object if the element is the instance of _Mutator. + """ + def process(self, element): + if isinstance(element, MutationGroup): + yield element + elif isinstance(element, _Mutator): + yield MutationGroup([element]) + else: + raise ValueError( + "Invalid object type: %s. 
Object must be an instance of " + "MutationGroup or WriteMutations" % str(element)) + + +class _WriteGroup(PTransform): + def __init__(self, max_batch_size_bytes, max_number_rows, max_number_cells): + self._max_batch_size_bytes = max_batch_size_bytes + self._max_number_rows = max_number_rows + self._max_number_cells = max_number_cells + + def expand(self, pcoll): + filter_batchable_mutations = ( + pcoll + | 'Making mutation groups' >> ParDo(_MakeMutationGroupsFn()) + | 'Filtering Batchable Mutations' >> ParDo( + _BatchableFilterFn( + max_batch_size_bytes=self._max_batch_size_bytes, + max_number_rows=self._max_number_rows, + max_number_cells=self._max_number_cells)).with_outputs( + _BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE, main='batchable') + ) + + batching_batchables = ( + filter_batchable_mutations['batchable'] + | ParDo( + _BatchFn( + max_batch_size_bytes=self._max_batch_size_bytes, + max_number_rows=self._max_number_rows, + max_number_cells=self._max_number_cells))) + + return (( + batching_batchables, + filter_batchable_mutations[_BatchableFilterFn.OUTPUT_TAG_UNBATCHABLE]) + | 'Merging batchable and unbatchable' >> Flatten()) diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py index b63d7cde5036..672bf8b7d242 100644 --- a/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py +++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py @@ -26,20 +26,27 @@ import mock import apache_beam as beam +from apache_beam.metrics.metric import MetricsFilter from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to # Protect against environments where spanner library is not available. 
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports +# pylint: disable=unused-import try: from google.cloud import spanner - from apache_beam.io.gcp.experimental.spannerio import ( - create_transaction, ReadOperation, ReadFromSpanner) # pylint: disable=unused-import - # disable=unused-import + from apache_beam.io.gcp.experimental.spannerio import create_transaction + from apache_beam.io.gcp.experimental.spannerio import ReadOperation + from apache_beam.io.gcp.experimental.spannerio import ReadFromSpanner + from apache_beam.io.gcp.experimental.spannerio import WriteMutation + from apache_beam.io.gcp.experimental.spannerio import MutationGroup + from apache_beam.io.gcp.experimental.spannerio import WriteToSpanner + from apache_beam.io.gcp.experimental.spannerio import _BatchFn except ImportError: spanner = None # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports +# pylint: enable=unused-import MAX_DB_NAME_LENGTH = 30 TEST_PROJECT_ID = 'apache-beam-testing' @@ -371,6 +378,238 @@ def test_display_data(self, *args): self.assertTrue("transaction" in dd_transaction) +@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.') +@mock.patch('apache_beam.io.gcp.experimental.spannerio.Client') +@mock.patch('google.cloud.spanner_v1.database.BatchCheckout') +class SpannerWriteTest(unittest.TestCase): + def test_spanner_write(self, mock_batch_snapshot_class, mock_batch_checkout): + ks = spanner.KeySet(keys=[[1233], [1234]]) + + mutations = [ + WriteMutation.delete("roles", ks), + WriteMutation.insert( + "roles", ("key", "rolename"), [('1233', "mutations-inset-1233")]), + WriteMutation.insert( + "roles", ("key", "rolename"), [('1234', "mutations-inset-1234")]), + WriteMutation.update( + "roles", ("key", "rolename"), + [('1234', "mutations-inset-1233-updated")]), + ] + + p = TestPipeline() + _ = ( + p + | beam.Create(mutations) + | WriteToSpanner( + project_id=TEST_PROJECT_ID, + instance_id=TEST_INSTANCE_ID, + database_id=_generate_database_name(), + max_batch_size_bytes=1024)) + res = p.run() + res.wait_until_finish() + + metric_results = res.metrics().query( + MetricsFilter().with_name("SpannerBatches")) + batches_counter = metric_results['counters'][0] + + self.assertEqual(batches_counter.committed, 2) + self.assertEqual(batches_counter.attempted, 2) + + def test_spanner_bundles_size( + self, mock_batch_snapshot_class, mock_batch_checkout): + ks = spanner.KeySet(keys=[[1233], [1234]]) + mutations = [ + WriteMutation.delete("roles", ks), + WriteMutation.insert( + "roles", ("key", "rolename"), [('1234', "mutations-inset-1234")]) + ] * 50 + p = TestPipeline() + _ = ( + p + | beam.Create(mutations) + | WriteToSpanner( + project_id=TEST_PROJECT_ID, + instance_id=TEST_INSTANCE_ID, + database_id=_generate_database_name(), + max_batch_size_bytes=1024)) + res = p.run() + res.wait_until_finish() + + metric_results = res.metrics().query( + MetricsFilter().with_name('SpannerBatches')) + batches_counter = metric_results['counters'][0] + + self.assertEqual(batches_counter.committed, 53) + self.assertEqual(batches_counter.attempted, 53) + + def test_spanner_write_mutation_groups( + self, mock_batch_snapshot_class, mock_batch_checkout): + ks = spanner.KeySet(keys=[[1233], [1234]]) + mutation_groups = [ + MutationGroup([ + WriteMutation.insert( + "roles", ("key", "rolename"), + [('9001233', "mutations-inset-1233")]), + WriteMutation.insert( + "roles", ("key", "rolename"), + [('9001234', "mutations-inset-1234")]) + ]), + MutationGroup([ + 
WriteMutation.update( + "roles", ("key", "rolename"), + [('9001234', "mutations-inset-9001233-updated")]) + ]), + MutationGroup([WriteMutation.delete("roles", ks)]) + ] + + p = TestPipeline() + _ = ( + p + | beam.Create(mutation_groups) + | WriteToSpanner( + project_id=TEST_PROJECT_ID, + instance_id=TEST_INSTANCE_ID, + database_id=_generate_database_name(), + max_batch_size_bytes=100)) + res = p.run() + res.wait_until_finish() + + metric_results = res.metrics().query( + MetricsFilter().with_name('SpannerBatches')) + batches_counter = metric_results['counters'][0] + + self.assertEqual(batches_counter.committed, 3) + self.assertEqual(batches_counter.attempted, 3) + + def test_batch_byte_size( + self, mock_batch_snapshot_class, mock_batch_checkout): + + # each mutation group byte size is 58 bytes. + mutation_group = [ + MutationGroup([ + WriteMutation.insert( + "roles", + ("key", "rolename"), [('1234', "mutations-inset-1234")]) + ]) + ] * 50 + + with TestPipeline() as p: + # the total 50 mutation group size will be 2900 (58 * 50) + # if we want to make two batches, so batch size should be 1450 (2900 / 2) + # and each bach should contains 25 mutations. + res = ( + p | beam.Create(mutation_group) + | beam.ParDo( + _BatchFn( + max_batch_size_bytes=1450, + max_number_rows=50, + max_number_cells=500)) + | beam.Map(lambda x: len(x))) + assert_that(res, equal_to([25] * 2)) + + def test_batch_disable(self, mock_batch_snapshot_class, mock_batch_checkout): + + mutation_group = [ + MutationGroup([ + WriteMutation.insert( + "roles", + ("key", "rolename"), [('1234', "mutations-inset-1234")]) + ]) + ] * 4 + + with TestPipeline() as p: + # to disable to batching, we need to set any of the batching parameters + # either to lower value or zero + res = ( + p | beam.Create(mutation_group) + | beam.ParDo( + _BatchFn( + max_batch_size_bytes=1450, + max_number_rows=0, + max_number_cells=500)) + | beam.Map(lambda x: len(x))) + assert_that(res, equal_to([1] * 4)) + + def test_batch_max_rows(self, mock_batch_snapshot_class, mock_batch_checkout): + + mutation_group = [ + MutationGroup([ + WriteMutation.insert( + "roles", ("key", "rolename"), + [ + ('1234', "mutations-inset-1234"), + ('1235', "mutations-inset-1235"), + ]) + ]) + ] * 50 + + with TestPipeline() as p: + # There are total 50 mutation groups, each contains two rows. + # The total number of rows will be 100 (50 * 2). + # If each batch contains 10 rows max then batch count should be 10 + # (contains 5 mutation groups each). + res = ( + p | beam.Create(mutation_group) + | beam.ParDo( + _BatchFn( + max_batch_size_bytes=1048576, + max_number_rows=10, + max_number_cells=500)) + | beam.Map(lambda x: len(x))) + assert_that(res, equal_to([5] * 10)) + + def test_batch_max_cells( + self, mock_batch_snapshot_class, mock_batch_checkout): + + mutation_group = [ + MutationGroup([ + WriteMutation.insert( + "roles", ("key", "rolename"), + [ + ('1234', "mutations-inset-1234"), + ('1235', "mutations-inset-1235"), + ]) + ]) + ] * 50 + + with TestPipeline() as p: + # There are total 50 mutation groups, each contains two rows (or 4 cells). + # The total number of cells will be 200 (50 groups * 4 cells). + # If each batch contains 50 cells max then batch count should be 5. + # 4 batches contains 12 mutations groups and the fifth batch should be + # consists of 2 mutation group element. + # No. 
of mutations groups per batch = Max Cells / Cells per mutation group + # total_batches = Total Number of Cells / Max Cells + res = ( + p | beam.Create(mutation_group) + | beam.ParDo( + _BatchFn( + max_batch_size_bytes=1048576, + max_number_rows=500, + max_number_cells=50)) + | beam.Map(lambda x: len(x))) + assert_that(res, equal_to([12, 12, 12, 12, 2])) + + def test_write_mutation_error(self, *args): + with self.assertRaises(ValueError): + # since `WriteMutation` only accept one operation. + WriteMutation(insert="table-name", update="table-name") + + def test_display_data(self, *args): + data = WriteToSpanner( + project_id=TEST_PROJECT_ID, + instance_id=TEST_INSTANCE_ID, + database_id=_generate_database_name(), + max_batch_size_bytes=1024).display_data() + self.assertTrue("project_id" in data) + self.assertTrue("instance_id" in data) + self.assertTrue("pool" in data) + self.assertTrue("database" in data) + self.assertTrue("batch_size" in data) + self.assertTrue("max_number_rows" in data) + self.assertTrue("max_number_cells" in data) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()