Skip to content
20 changes: 16 additions & 4 deletions sdks/python/apache_beam/testing/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import glob
import io
import tempfile
from typing import Any
from typing import Iterable
from typing import List
from typing import NamedTuple

from apache_beam import pvalue
from apache_beam.transforms import window
Expand All @@ -35,6 +38,8 @@
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.ptransform import ptransform_fn
from apache_beam.transforms.util import CoGroupByKey
from apache_beam.utils.windowed_value import PANE_INFO_UNKNOWN
from apache_beam.utils.windowed_value import PaneInfo

__all__ = [
'assert_that',
Expand All @@ -56,8 +61,11 @@ class BeamAssertException(Exception):


# Used for reifying timestamps and windows for assert_that matchers.
TestWindowedValue = collections.namedtuple(
'TestWindowedValue', 'value timestamp windows')
class TestWindowedValue(NamedTuple):
value: Any
timestamp: Any
windows: List
pane_info: PaneInfo = PANE_INFO_UNKNOWN


def contains_in_any_order(iterable):
Expand Down Expand Up @@ -290,11 +298,15 @@ def assert_that(

class ReifyTimestampWindow(DoFn):
def process(
self, element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
self,
element,
timestamp=DoFn.TimestampParam,
window=DoFn.WindowParam,
pane_info=DoFn.PaneInfoParam):
# This returns TestWindowedValue instead of
# beam.utils.windowed_value.WindowedValue because ParDo will extract
# the timestamp and window out of the latter.
return [TestWindowedValue(element, timestamp, [window])]
return [TestWindowedValue(element, timestamp, [window], pane_info)]

class AddWindow(DoFn):
def process(self, element, window=DoFn.WindowParam):
Expand Down
78 changes: 75 additions & 3 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,8 +933,6 @@ def is_compat_version_prior_to(options, breaking_change_version):
# keep the old behavior prior to a breaking change or use the new behavior.
# - If update_compatibility_version < breaking_change_version, we will return
# True and keep the old behavior.
# - If update_compatibility_version is None or >= breaking_change_version, we
# will return False and use the behavior from the breaking change.
update_compatibility_version = options.view_as(
pipeline_options.StreamingOptions).update_compatibility_version

Expand All @@ -949,6 +947,53 @@ def is_compat_version_prior_to(options, breaking_change_version):
return False


def reify_metadata_default_window(
element, timestamp=DoFn.TimestampParam, pane_info=DoFn.PaneInfoParam):
key, value = element
if timestamp == window.MIN_TIMESTAMP:
timestamp = None
return key, (value, timestamp, pane_info)


def restore_metadata_default_window(element):
key, values = element
return [
window.GlobalWindows.windowed_value(None).with_value((key, value))
if timestamp is None else window.GlobalWindows.windowed_value(
value=(key, value), timestamp=timestamp, pane_info=pane_info)
for (value, timestamp, pane_info) in values
]


def reify_metadata_custom_window(
element,
timestamp=DoFn.TimestampParam,
window=DoFn.WindowParam,
pane_info=DoFn.PaneInfoParam):
key, value = element
return key, windowed_value.WindowedValue(
value, timestamp, [window], pane_info)


def restore_metadata_custom_window(element):
key, windowed_values = element
return [wv.with_value((key, wv.value)) for wv in windowed_values]


def _reify_restore_metadata(is_default_windowing):
if is_default_windowing:
return reify_metadata_default_window, restore_metadata_default_window
return reify_metadata_custom_window, restore_metadata_custom_window


def _add_pre_map_gkb_types(pre_gbk_map, is_default_windowing):
if is_default_windowing:
return pre_gbk_map.with_input_types(tuple[K, V]).with_output_types(
tuple[K, tuple[V, Optional[Timestamp], windowed_value.PaneInfo]])
return pre_gbk_map.with_input_types(tuple[K, V]).with_output_types(
tuple[K, TypedWindowedValue[V]])


@typehints.with_input_types(tuple[K, V])
@typehints.with_output_types(tuple[K, V])
class ReshufflePerKey(PTransform):
Expand All @@ -957,7 +1002,7 @@ class ReshufflePerKey(PTransform):
in particular checkpointing, and preventing fusion of the surrounding
transforms.
"""
def expand(self, pcoll):
def expand_2_64_0(self, pcoll):
windowing_saved = pcoll.windowing
if windowing_saved.is_default():
# In this (common) case we can use a trivial trigger driver
Expand Down Expand Up @@ -1023,6 +1068,33 @@ def restore_timestamps(element):
result._windowing = windowing_saved
return result

def expand(self, pcoll):
if is_compat_version_prior_to(pcoll.pipeline.options, "2.65.0"):
return self.expand_2_64_0(pcoll)

windowing_saved = pcoll.windowing
is_default_windowing = windowing_saved.is_default()
reify_fn, restore_fn = _reify_restore_metadata(is_default_windowing)

pre_gbk_map = _add_pre_map_gkb_types(Map(reify_fn), is_default_windowing)

ungrouped = pcoll | pre_gbk_map

# TODO(https://github.com/apache/beam/issues/19785) Using global window as
# one of the standard window. This is to mitigate the Dataflow Java Runner
# Harness limitation to accept only standard coders.
ungrouped._windowing = Windowing(
window.GlobalWindows(),
triggerfn=Always(),
accumulation_mode=AccumulationMode.DISCARDING,
timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
result = (
ungrouped
| GroupByKey()
| FlatMap(restore_fn).with_output_types(Any))
result._windowing = windowing_saved
return result


@typehints.with_input_types(T)
@typehints.with_output_types(T)
Expand Down
Loading
Loading