From 92ff3a63bb5727cfb9e7270610d95b46270240df Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 6 Feb 2018 13:35:47 -0800 Subject: [PATCH] Optimize reshuffle. --- sdks/python/apache_beam/transforms/trigger.py | 15 ++-- .../apache_beam/transforms/trigger_test.py | 2 +- sdks/python/apache_beam/transforms/util.py | 73 +++++++++++++------ 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index d47c740d0e12..f0201d9fec51 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -34,6 +34,7 @@ from apache_beam.transforms import core from apache_beam.transforms.timeutil import TimeDomain from apache_beam.transforms.window import GlobalWindow +from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import TimestampCombiner from apache_beam.transforms.window import WindowedValue from apache_beam.transforms.window import WindowFn @@ -902,7 +903,12 @@ def create_trigger_driver(windowing, # TODO(robertwb): We can do more if we know elements are in timestamp # sorted order. if windowing.is_default() and is_batch: - driver = DefaultGlobalBatchTriggerDriver() + driver = DiscardingGlobalTriggerDriver() + elif (windowing.windowfn == GlobalWindows() + and windowing.triggerfn == AfterCount(1) + and windowing.accumulation_mode == AccumulationMode.DISCARDING): + # Here we also just pass through all the values every time. + driver = DiscardingGlobalTriggerDriver() else: driver = GeneralTriggerDriver(windowing, clock) @@ -971,14 +977,11 @@ def __ne__(self, other): return not self == other -class DefaultGlobalBatchTriggerDriver(TriggerDriver): - """Breaks a bundles into window (pane)s according to the default triggering. +class DiscardingGlobalTriggerDriver(TriggerDriver): + """Groups all received values together. """ GLOBAL_WINDOW_TUPLE = (GlobalWindow(),) - def __init__(self): - pass - def process_elements(self, state, windowed_values, unused_output_watermark): yield WindowedValue( _UnwindowedValues(windowed_values), diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py index a765493ae810..2e672bb0cf1b 100644 --- a/sdks/python/apache_beam/transforms/trigger_test.py +++ b/sdks/python/apache_beam/transforms/trigger_test.py @@ -375,7 +375,7 @@ def test_sessions_after_each(self): def test_picklable_output(self): global_window = trigger.GlobalWindow(), - driver = trigger.DefaultGlobalBatchTriggerDriver() + driver = trigger.DiscardingGlobalTriggerDriver() unpicklable = (WindowedValue(k, 0, global_window) for k in range(10)) with self.assertRaises(TypeError): diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 2be94332cf3c..61c2eaffd8fa 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -35,7 +35,7 @@ from apache_beam.transforms.core import GroupByKey from apache_beam.transforms.core import Map from apache_beam.transforms.core import ParDo -from apache_beam.transforms.core import WindowInto +from apache_beam.transforms.core import Windowing from apache_beam.transforms.ptransform import PTransform from apache_beam.transforms.ptransform import ptransform_fn from apache_beam.transforms.trigger import AccumulationMode @@ -485,33 +485,58 @@ class ReshufflePerKey(PTransform): """ def expand(self, pcoll): - class ReifyTimestamps(DoFn): - def process(self, element, timestamp=DoFn.TimestampParam): - yield element[0], TimestampedValue(element[1], timestamp) + windowing_saved = pcoll.windowing + if windowing_saved.is_default(): + # In this (common) case we can use a trivial trigger driver + # and avoid the (expensive) window param. + globally_windowed = window.GlobalWindows.windowed_value(None) + window_fn = window.GlobalWindows() + MIN_TIMESTAMP = window.MIN_TIMESTAMP + + def reify_timestamps(element, timestamp=DoFn.TimestampParam): + key, value = element + if timestamp == MIN_TIMESTAMP: + timestamp = None + return key, (value, timestamp) + + def restore_timestamps(element): + key, values = element + return [ + globally_windowed.with_value((key, value)) + if timestamp is None + else window.GlobalWindows.windowed_value((key, value), timestamp) + for (value, timestamp) in values] - class RestoreTimestamps(DoFn): - def process(self, element, window=DoFn.WindowParam): + else: + # The linter is confused. + # hash(1) is used to force "runtime" selection of _IdentityWindowFn + # pylint: disable=abstract-class-instantiated + cls = hash(1) and _IdentityWindowFn + window_fn = cls( + windowing_saved.windowfn.get_window_coder()) + + def reify_timestamps(element, timestamp=DoFn.TimestampParam): + key, value = element + return key, TimestampedValue(value, timestamp) + + def restore_timestamps(element, window=DoFn.WindowParam): # Pass the current window since _IdentityWindowFn wouldn't know how # to generate it. - yield windowed_value.WindowedValue( - (element[0], element[1].value), element[1].timestamp, [window]) - - windowing_saved = pcoll.windowing - # The linter is confused. - # pylint: disable=abstract-class-instantiated - result = (pcoll - | ParDo(ReifyTimestamps()) - | 'IdentityWindow' >> WindowInto( - _IdentityWindowFn( - windowing_saved.windowfn.get_window_coder()), - trigger=AfterCount(1), - accumulation_mode=AccumulationMode.DISCARDING, - timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST, - ) + key, values = element + return [ + windowed_value.WindowedValue( + (key, value.value), value.timestamp, [window]) + for value in values] + + ungrouped = pcoll | Map(reify_timestamps) + ungrouped._windowing = Windowing( + window_fn, + triggerfn=AfterCount(1), + accumulation_mode=AccumulationMode.DISCARDING, + timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST) + result = (ungrouped | GroupByKey() - | 'ExpandIterable' >> FlatMap( - lambda e: [(e[0], value) for value in e[1]]) - | ParDo(RestoreTimestamps())) + | FlatMap(restore_timestamps)) result._windowing = windowing_saved return result