Merged
40 changes: 31 additions & 9 deletions sdks/python/apache_beam/io/gcp/bigtableio.py
@@ -55,15 +55,14 @@
from apache_beam.typehints.row_type import RowTypeConstraint

_LOGGER = logging.getLogger(__name__)
FLUSH_COUNT = 1000
MAX_ROW_BYTES = 5242880 # 5MB

try:
from google.cloud.bigtable import Client
from google.cloud.bigtable.row import Cell, PartialRowData
from google.cloud.bigtable.batcher import MutationsBatcher

FLUSH_COUNT = 1000
MAX_ROW_BYTES = 5242880 # 5MB

except ImportError:
_LOGGER.warning(
'ImportError: from google.cloud.bigtable import Client', exc_info=True)
@@ -78,20 +77,27 @@ class _BigTableWriteFn(beam.DoFn):
project_id(str): GCP Project ID
instance_id(str): GCP Instance ID
table_id(str): GCP Table ID
flush_count(int): Max number of rows to flush
max_row_bytes(int): Max size in bytes of row mutations to flush

"""
def __init__(self, project_id, instance_id, table_id):
def __init__(
self, project_id, instance_id, table_id, flush_count, max_row_bytes):
""" Constructor of the Write connector of Bigtable
Args:
project_id(str): GCP Project to write the Rows to
instance_id(str): GCP Instance to write the Rows
table_id(str): GCP Table to write the `DirectRows`
flush_count(int): Max number of rows to flush
max_row_bytes(int): Max size in bytes of row mutations to flush
"""
super().__init__()
self.beam_options = {
'project_id': project_id,
'instance_id': instance_id,
'table_id': table_id
'table_id': table_id,
'flush_count': flush_count,
'max_row_bytes': max_row_bytes,
}
self.table = None
self.batcher = None
@@ -144,8 +150,8 @@ def start_bundle(self):
self.batcher = MutationsBatcher(
self.table,
batch_completed_callback=self.write_mutate_metrics,
flush_count=FLUSH_COUNT,
max_row_bytes=MAX_ROW_BYTES)
flush_count=self.beam_options['flush_count'],
max_row_bytes=self.beam_options['max_row_bytes'])

def process(self, row):
self.written.inc()
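As context for this hunk, a standalone sketch of the batcher semantics the two new knobs control; the client, instance, table, and column-family names are placeholders, not from this PR:

```python
from google.cloud.bigtable import Client
from google.cloud.bigtable.batcher import MutationsBatcher

client = Client(project='my-project')
table = client.instance('my-instance').table('my-table')

# With flush_count=2, a flush fires once two rows are buffered;
# max_row_bytes instead bounds the total byte size of buffered mutations.
batcher = MutationsBatcher(table, flush_count=2, max_row_bytes=1024)
for key in (b'row-1', b'row-2', b'row-3'):
  row = table.direct_row(key)
  row.set_cell('cf', b'col', b'value')
  batcher.mutate(row)
batcher.flush()  # flush anything still buffered (here, row-3)
```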
@@ -200,7 +206,10 @@ def __init__(
instance_id,
table_id,
use_cross_language=False,
expansion_service=None):
expansion_service=None,
flush_count=FLUSH_COUNT,
max_row_bytes=MAX_ROW_BYTES,
):
"""Initialize an WriteToBigTable transform.

:param table_id:
@@ -215,6 +224,12 @@
The address of the expansion service in the case of using cross-language.
If no expansion service is provided, will attempt to run the default GCP
expansion service.
:type flush_count: int
:param flush_count: (Optional) Max number of rows to flush.
Default is FLUSH_COUNT (1000 rows).
:type max_row_bytes: int
:param max_row_bytes: (Optional) Max size in bytes of row mutations to flush.
Default is MAX_ROW_BYTES (5 MB).
"""
super().__init__()
self._table_id = table_id
@@ -229,6 +244,9 @@ def __init__(
SchemaAwareExternalTransform.discover_config(
self._expansion_service, self.URN))

self._flush_count = flush_count
Contributor review comment: we probably need to log warnings if using cross-lang is true since these two parameters are not supported. (Sketched below.)
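A hypothetical sketch of that warning, reusing names already in this file; it is not part of the PR's diff:

```python
# Hypothetical: warn when the Python-only knobs are set on the
# cross-language path, where they would be ignored.
if use_cross_language and (flush_count != FLUSH_COUNT or
                           max_row_bytes != MAX_ROW_BYTES):
  _LOGGER.warning(
      'flush_count and max_row_bytes are not supported by the '
      'cross-language BigTable write and will be ignored.')
```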

self._max_row_bytes = max_row_bytes

def expand(self, input):
if self._use_cross_language:
external_write = SchemaAwareExternalTransform(
@@ -250,7 +268,11 @@ def expand(self, input):
input
| beam.ParDo(
_BigTableWriteFn(
self._project_id, self._instance_id, self._table_id)))
self._project_id,
self._instance_id,
self._table_id,
flush_count=self._flush_count,
max_row_bytes=self._max_row_bytes)))

class _DirectRowMutationsToBeamRow(beam.DoFn):
def process(self, direct_row):
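For context, a minimal sketch of how a pipeline would exercise the new knobs; the project, instance, table, and column-family names are placeholders:

```python
import apache_beam as beam
from apache_beam.io.gcp import bigtableio
from google.cloud.bigtable.row import DirectRow

def make_row(i):
  # DirectRow needs no table reference here; the sink binds the table.
  row = DirectRow(row_key=b'key-%d' % i)
  row.set_cell('cf', b'col', b'value-%d' % i)
  return row

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([make_row(i) for i in range(10)])
      | bigtableio.WriteToBigTable(
          'my-project',
          'my-instance',
          'my-table',
          flush_count=500,  # flush after 500 buffered rows...
          max_row_bytes=1024 * 1024))  # ...or after ~1 MB of mutations
```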
24 changes: 23 additions & 1 deletion sdks/python/apache_beam/io/gcp/bigtableio_test.py
@@ -284,7 +284,11 @@ def test_write(self):
def test_write_metrics(self):
MetricsEnvironment.process_wide_container().reset()
write_fn = bigtableio._BigTableWriteFn(
self._PROJECT_ID, self._INSTANCE_ID, self._TABLE_ID)
self._PROJECT_ID,
self._INSTANCE_ID,
self._TABLE_ID,
flush_count=1000,
max_row_bytes=5242880)
write_fn.table = self.table
write_fn.start_bundle()
number_of_rows = 2
@@ -363,6 +367,24 @@ def verify_write_call_metric(
self.assertTrue(
found, "Did not find write call metric with status: %s" % status)

def test_custom_flush_config(self):
direct_rows = [self.generate_row(0)]
with patch.object(
MutationsBatcher, '__init__', return_value=None) as mock_init, \
patch.object(MutationsBatcher, 'mutate'), \
patch.object(MutationsBatcher, 'close'), TestPipeline() as p:
_ = p | beam.Create(direct_rows) | bigtableio.WriteToBigTable(
self._PROJECT_ID,
self._INSTANCE_ID,
self._TABLE_ID,
flush_count=1001,
max_row_bytes=5000001)

mock_init.assert_called_once()
call_args = mock_init.call_args.kwargs
assert call_args['flush_count'] == 1001
assert call_args['max_row_bytes'] == 5000001


if __name__ == '__main__':
unittest.main()