diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index bcedd867386c..7e6b1f9d5b88 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -21,6 +21,7 @@ import copy import heapq +import itertools import operator import random from typing import Any @@ -597,16 +598,24 @@ def teardown(self): class _TupleCombineFnBase(core.CombineFn): - def __init__(self, *combiners): + def __init__(self, *combiners, merge_accumulators_batch_size=None): self._combiners = [core.CombineFn.maybe_from_callable(c) for c in combiners] self._named_combiners = combiners + # If the `merge_accumulators_batch_size` value is not specified, we choose a + # bounded default that is inversely proportional to the number of + # accumulators in merged tuples. + self._merge_accumulators_batch_size = ( + merge_accumulators_batch_size or max(10, 1000 // len(combiners))) def display_data(self): combiners = [ c.__name__ if hasattr(c, '__name__') else c.__class__.__name__ for c in self._named_combiners ] - return {'combiners': str(combiners)} + return { + 'combiners': str(combiners), + 'merge_accumulators_batch_size': self._merge_accumulators_batch_size + } def setup(self, *args, **kwargs): for c in self._combiners: @@ -616,10 +625,23 @@ def create_accumulator(self, *args, **kwargs): return [c.create_accumulator(*args, **kwargs) for c in self._combiners] def merge_accumulators(self, accumulators, *args, **kwargs): - return [ - c.merge_accumulators(a, *args, **kwargs) for c, - a in zip(self._combiners, zip(*accumulators)) - ] + # Make sure that `accumulators` is an iterator (so that the position is + # remembered). + accumulators = iter(accumulators) + result = next(accumulators) + while True: + # Load accumulators into memory and merge in batches to decrease peak + # memory usage. 
+ accumulators_batch = list( + itertools.islice(accumulators, self._merge_accumulators_batch_size)) + if not accumulators_batch: + break + accumulators_batch += [result] + result = [ + c.merge_accumulators(a, *args, **kwargs) for c, + a in zip(self._combiners, zip(*accumulators_batch)) + ] + return result def compact(self, accumulator, *args, **kwargs): return [ diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py index d82628791ae4..68b273e930b2 100644 --- a/sdks/python/apache_beam/transforms/combiners_test.py +++ b/sdks/python/apache_beam/transforms/combiners_test.py @@ -249,7 +249,8 @@ def test_basic_combiners_display_data(self): dd = DisplayData.create_from(transform) expected_items = [ DisplayDataItemMatcher('combine_fn', combine.TupleCombineFn), - DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']") + DisplayDataItemMatcher('combiners', "['max', 'MeanCombineFn', 'sum']"), + DisplayDataItemMatcher('merge_accumulators_batch_size', 333), ] hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items)) @@ -358,6 +359,41 @@ def test_tuple_combine_fn_without_defaults(self): max).with_common_input()).without_defaults()) assert_that(result, equal_to([(1, 7.0 / 4, 3)])) + def test_tuple_combine_fn_batched_merge(self): + num_combine_fns = 10 + max_num_accumulators_in_memory = 30 + # Maximum number of accumulator tuples in memory - 1 for the merge result. 
+ merge_accumulators_batch_size = ( + max_num_accumulators_in_memory // num_combine_fns - 1) + num_accumulator_tuples_to_merge = 20 + + class CountedAccumulator: + count = 0 + oom = False + + def __init__(self): + if CountedAccumulator.count > max_num_accumulators_in_memory: + CountedAccumulator.oom = True + else: + CountedAccumulator.count += 1 + + class CountedAccumulatorCombineFn(beam.CombineFn): + def create_accumulator(self): + return CountedAccumulator() + + def merge_accumulators(self, accumulators): + CountedAccumulator.count += 1 + for _ in accumulators: + CountedAccumulator.count -= 1 + + combine_fn = combine.TupleCombineFn( + *[CountedAccumulatorCombineFn() for _ in range(num_combine_fns)], + merge_accumulators_batch_size=merge_accumulators_batch_size) + combine_fn.merge_accumulators( + combine_fn.create_accumulator() + for _ in range(num_accumulator_tuples_to_merge)) + assert not CountedAccumulator.oom + def test_to_list_and_to_dict1(self): with TestPipeline() as pipeline: the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]