Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions sdks/python/apache_beam/transforms/combiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import copy
import heapq
import itertools
import operator
import random
from typing import Any
Expand Down Expand Up @@ -597,16 +598,24 @@ def teardown(self):


class _TupleCombineFnBase(core.CombineFn):
def __init__(self, *combiners, merge_accumulators_batch_size=None):
  """Initializes the tuple-combining CombineFn.

  Args:
    *combiners: CombineFns (or callables convertible via
      ``CombineFn.maybe_from_callable``) applied component-wise to the
      corresponding position of each input tuple.
    merge_accumulators_batch_size: Maximum number of accumulator tuples
      loaded into memory per merge batch. If not specified, defaults to a
      bound that is inversely proportional to the number of combiners, so
      the total number of in-memory per-combiner accumulators stays roughly
      constant.
  """
  self._combiners = [core.CombineFn.maybe_from_callable(c) for c in combiners]
  self._named_combiners = combiners
  # If the `merge_accumulators_batch_size` value is not specified, we chose a
  # bounded default that is inversely proportional to the number of
  # accumulators in merged tuples. Guard the divisor so that constructing
  # with zero combiners does not raise ZeroDivisionError.
  self._merge_accumulators_batch_size = (
      merge_accumulators_batch_size or
      max(10, 1000 // max(1, len(combiners))))

def display_data(self):
  """Reports the component combiner names and the merge batch size."""
  def _label(c):
    # Callables (e.g. `max`) have a __name__; CombineFn instances do not.
    return c.__name__ if hasattr(c, '__name__') else c.__class__.__name__

  return {
      'combiners': str([_label(c) for c in self._named_combiners]),
      'merge_accumulators_batch_size': self._merge_accumulators_batch_size,
  }

def setup(self, *args, **kwargs):
for c in self._combiners:
Expand All @@ -616,10 +625,23 @@ def create_accumulator(self, *args, **kwargs):
return [c.create_accumulator(*args, **kwargs) for c in self._combiners]

def merge_accumulators(self, accumulators, *args, **kwargs):
  """Merges accumulator tuples, batching to bound peak memory usage.

  Rather than materializing every accumulator tuple at once, the input is
  consumed in batches of at most ``self._merge_accumulators_batch_size``
  tuples, folding each batch (plus the running result) component-wise into
  a single merged tuple.

  Args:
    accumulators: iterable of accumulator tuples, each with one entry per
      component combiner.

  Returns:
    A single merged accumulator (a list with one entry per combiner).
  """
  # Make sure that `accumulators` is an iterator (so that the position is
  # remembered across the islice() calls below).
  accumulators = iter(accumulators)
  sentinel = object()
  result = next(accumulators, sentinel)
  if result is sentinel:
    # An empty input merges to the identity accumulator instead of leaking
    # StopIteration (which PEP 479 turns into RuntimeError inside
    # generator frames).
    return self.create_accumulator(*args, **kwargs)
  while True:
    # Load accumulators into memory and merge in batches to decrease peak
    # memory usage.
    accumulators_batch = list(
        itertools.islice(accumulators, self._merge_accumulators_batch_size))
    if not accumulators_batch:
      break
    accumulators_batch.append(result)
    result = [
        c.merge_accumulators(a, *args, **kwargs)
        for c, a in zip(self._combiners, zip(*accumulators_batch))
    ]
  return result

def compact(self, accumulator, *args, **kwargs):
return [
Expand Down
38 changes: 37 additions & 1 deletion sdks/python/apache_beam/transforms/combiners_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ def test_basic_combiners_display_data(self):
dd = DisplayData.create_from(transform)
expected_items = [
DisplayDataItemMatcher('combine_fn', combine.TupleCombineFn),
DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']")
DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']"),
DisplayDataItemMatcher('merge_accumulators_batch_size', 333),
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

Expand Down Expand Up @@ -358,6 +359,41 @@ def test_tuple_combine_fn_without_defaults(self):
max).with_common_input()).without_defaults())
assert_that(result, equal_to([(1, 7.0 / 4, 3)]))

def test_tuple_combine_fn_batched_merge(self):
  """Batched merging must never exceed the accumulator memory budget."""
  fan_out = 10  # number of component combiners in the tuple
  memory_budget = 30  # hard cap on simultaneously live accumulators
  # One tuple's worth of the budget is reserved for the running merge
  # result, hence the `- 1`.
  batch_size = memory_budget // fan_out - 1
  tuples_to_merge = 20

  class TrackedAccumulator:
    count = 0
    oom = False

    def __init__(self):
      if TrackedAccumulator.count > memory_budget:
        TrackedAccumulator.oom = True
      else:
        TrackedAccumulator.count += 1

  class TrackedAccumulatorCombineFn(beam.CombineFn):
    def create_accumulator(self):
      return TrackedAccumulator()

    def merge_accumulators(self, accumulators):
      # Net effect: consume the inputs, produce one (notional) result.
      TrackedAccumulator.count += 1
      for _ in accumulators:
        TrackedAccumulator.count -= 1

  component_fns = [TrackedAccumulatorCombineFn() for _ in range(fan_out)]
  combine_fn = combine.TupleCombineFn(
      *component_fns, merge_accumulators_batch_size=batch_size)
  combine_fn.merge_accumulators(
      combine_fn.create_accumulator() for _ in range(tuples_to_merge))
  assert not TrackedAccumulator.oom

def test_to_list_and_to_dict1(self):
with TestPipeline() as pipeline:
the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
Expand Down