Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions sdks/python/apache_beam/transforms/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from apache_beam.transforms import core
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import WindowedValue
from apache_beam.transforms.window import WindowFn
Expand Down Expand Up @@ -902,7 +903,12 @@ def create_trigger_driver(windowing,
# TODO(robertwb): We can do more if we know elements are in timestamp
# sorted order.
if windowing.is_default() and is_batch:
driver = DefaultGlobalBatchTriggerDriver()
driver = DiscardingGlobalTriggerDriver()
elif (windowing.windowfn == GlobalWindows()
and windowing.triggerfn == AfterCount(1)
and windowing.accumulation_mode == AccumulationMode.DISCARDING):
# Here we also just pass through all the values every time.
driver = DiscardingGlobalTriggerDriver()
else:
driver = GeneralTriggerDriver(windowing, clock)

Expand Down Expand Up @@ -971,14 +977,11 @@ def __ne__(self, other):
return not self == other


class DefaultGlobalBatchTriggerDriver(TriggerDriver):
"""Breaks a bundles into window (pane)s according to the default triggering.
class DiscardingGlobalTriggerDriver(TriggerDriver):
"""Groups all received values together.
"""
GLOBAL_WINDOW_TUPLE = (GlobalWindow(),)

def __init__(self):
pass

def process_elements(self, state, windowed_values, unused_output_watermark):
yield WindowedValue(
_UnwindowedValues(windowed_values),
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/transforms/trigger_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def test_sessions_after_each(self):

def test_picklable_output(self):
global_window = trigger.GlobalWindow(),
driver = trigger.DefaultGlobalBatchTriggerDriver()
driver = trigger.DiscardingGlobalTriggerDriver()
unpicklable = (WindowedValue(k, 0, global_window)
for k in range(10))
with self.assertRaises(TypeError):
Expand Down
73 changes: 49 additions & 24 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from apache_beam.transforms.core import GroupByKey
from apache_beam.transforms.core import Map
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import WindowInto
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.ptransform import ptransform_fn
from apache_beam.transforms.trigger import AccumulationMode
Expand Down Expand Up @@ -485,33 +485,58 @@ class ReshufflePerKey(PTransform):
"""

def expand(self, pcoll):
class ReifyTimestamps(DoFn):
def process(self, element, timestamp=DoFn.TimestampParam):
yield element[0], TimestampedValue(element[1], timestamp)
windowing_saved = pcoll.windowing
if windowing_saved.is_default():
# In this (common) case we can use a trivial trigger driver
# and avoid the (expensive) window param.
globally_windowed = window.GlobalWindows.windowed_value(None)
window_fn = window.GlobalWindows()
MIN_TIMESTAMP = window.MIN_TIMESTAMP

def reify_timestamps(element, timestamp=DoFn.TimestampParam):
key, value = element
if timestamp == MIN_TIMESTAMP:
timestamp = None
return key, (value, timestamp)

def restore_timestamps(element):
key, values = element
return [
globally_windowed.with_value((key, value))
if timestamp is None
else window.GlobalWindows.windowed_value((key, value), timestamp)
for (value, timestamp) in values]

class RestoreTimestamps(DoFn):
def process(self, element, window=DoFn.WindowParam):
else:
# The linter is confused.
# hash(1) is used to force "runtime" selection of _IdentityWindowFn
# pylint: disable=abstract-class-instantiated
cls = hash(1) and _IdentityWindowFn
window_fn = cls(
windowing_saved.windowfn.get_window_coder())

def reify_timestamps(element, timestamp=DoFn.TimestampParam):
key, value = element
return key, TimestampedValue(value, timestamp)

def restore_timestamps(element, window=DoFn.WindowParam):
# Pass the current window since _IdentityWindowFn wouldn't know how
# to generate it.
yield windowed_value.WindowedValue(
(element[0], element[1].value), element[1].timestamp, [window])

windowing_saved = pcoll.windowing
# The linter is confused.
# pylint: disable=abstract-class-instantiated
result = (pcoll
| ParDo(ReifyTimestamps())
| 'IdentityWindow' >> WindowInto(
_IdentityWindowFn(
windowing_saved.windowfn.get_window_coder()),
trigger=AfterCount(1),
accumulation_mode=AccumulationMode.DISCARDING,
timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST,
)
key, values = element
return [
windowed_value.WindowedValue(
(key, value.value), value.timestamp, [window])
for value in values]

ungrouped = pcoll | Map(reify_timestamps)
ungrouped._windowing = Windowing(
window_fn,
triggerfn=AfterCount(1),
accumulation_mode=AccumulationMode.DISCARDING,
timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
result = (ungrouped
| GroupByKey()
| 'ExpandIterable' >> FlatMap(
lambda e: [(e[0], value) for value in e[1]])
| ParDo(RestoreTimestamps()))
| FlatMap(restore_timestamps))
result._windowing = windowing_saved
return result

Expand Down