diff --git a/sdks/python/apache_beam/testing/util.py b/sdks/python/apache_beam/testing/util.py index 5f2d211d2b72..cbb2119b83f6 100644 --- a/sdks/python/apache_beam/testing/util.py +++ b/sdks/python/apache_beam/testing/util.py @@ -23,7 +23,10 @@ import glob import io import tempfile +from typing import Any from typing import Iterable +from typing import List +from typing import NamedTuple from apache_beam import pvalue from apache_beam.transforms import window @@ -35,6 +38,8 @@ from apache_beam.transforms.ptransform import PTransform from apache_beam.transforms.ptransform import ptransform_fn from apache_beam.transforms.util import CoGroupByKey +from apache_beam.utils.windowed_value import PANE_INFO_UNKNOWN +from apache_beam.utils.windowed_value import PaneInfo __all__ = [ 'assert_that', @@ -56,8 +61,11 @@ class BeamAssertException(Exception): # Used for reifying timestamps and windows for assert_that matchers. -TestWindowedValue = collections.namedtuple( - 'TestWindowedValue', 'value timestamp windows') +class TestWindowedValue(NamedTuple): + value: Any + timestamp: Any + windows: List + pane_info: PaneInfo = PANE_INFO_UNKNOWN def contains_in_any_order(iterable): @@ -290,11 +298,15 @@ def assert_that( class ReifyTimestampWindow(DoFn): def process( - self, element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam): + self, + element, + timestamp=DoFn.TimestampParam, + window=DoFn.WindowParam, + pane_info=DoFn.PaneInfoParam): # This returns TestWindowedValue instead of # beam.utils.windowed_value.WindowedValue because ParDo will extract # the timestamp and window out of the latter. - return [TestWindowedValue(element, timestamp, [window])] + return [TestWindowedValue(element, timestamp, [window], pane_info)] class AddWindow(DoFn): def process(self, element, window=DoFn.WindowParam): diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 812c95c36519..7c3a1929ba9d 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -933,8 +933,6 @@ def is_compat_version_prior_to(options, breaking_change_version): # keep the old behavior prior to a breaking change or use the new behavior. # - If update_compatibility_version < breaking_change_version, we will return # True and keep the old behavior. - # - If update_compatibility_version is None or >= breaking_change_version, we - # will return False and use the behavior from the breaking change. update_compatibility_version = options.view_as( pipeline_options.StreamingOptions).update_compatibility_version @@ -949,6 +947,53 @@ def is_compat_version_prior_to(options, breaking_change_version): return False +def reify_metadata_default_window( + element, timestamp=DoFn.TimestampParam, pane_info=DoFn.PaneInfoParam): + key, value = element + if timestamp == window.MIN_TIMESTAMP: + timestamp = None + return key, (value, timestamp, pane_info) + + +def restore_metadata_default_window(element): + key, values = element + return [ + window.GlobalWindows.windowed_value(None).with_value((key, value)) + if timestamp is None else window.GlobalWindows.windowed_value( + value=(key, value), timestamp=timestamp, pane_info=pane_info) + for (value, timestamp, pane_info) in values + ] + + +def reify_metadata_custom_window( + element, + timestamp=DoFn.TimestampParam, + window=DoFn.WindowParam, + pane_info=DoFn.PaneInfoParam): + key, value = element + return key, windowed_value.WindowedValue( + value, timestamp, [window], pane_info) + + +def restore_metadata_custom_window(element): + key, windowed_values = element + return [wv.with_value((key, wv.value)) for wv in windowed_values] + + +def _reify_restore_metadata(is_default_windowing): + if is_default_windowing: + return reify_metadata_default_window, restore_metadata_default_window + return reify_metadata_custom_window, restore_metadata_custom_window + + +def _add_pre_map_gkb_types(pre_gbk_map, is_default_windowing): + if is_default_windowing: + return pre_gbk_map.with_input_types(tuple[K, V]).with_output_types( + tuple[K, tuple[V, Optional[Timestamp], windowed_value.PaneInfo]]) + return pre_gbk_map.with_input_types(tuple[K, V]).with_output_types( + tuple[K, TypedWindowedValue[V]]) + + @typehints.with_input_types(tuple[K, V]) @typehints.with_output_types(tuple[K, V]) class ReshufflePerKey(PTransform): @@ -957,7 +1002,7 @@ class ReshufflePerKey(PTransform): in particular checkpointing, and preventing fusion of the surrounding transforms. """ - def expand(self, pcoll): + def expand_2_64_0(self, pcoll): windowing_saved = pcoll.windowing if windowing_saved.is_default(): # In this (common) case we can use a trivial trigger driver @@ -1023,6 +1068,33 @@ def restore_timestamps(element): result._windowing = windowing_saved return result + def expand(self, pcoll): + if is_compat_version_prior_to(pcoll.pipeline.options, "2.65.0"): + return self.expand_2_64_0(pcoll) + + windowing_saved = pcoll.windowing + is_default_windowing = windowing_saved.is_default() + reify_fn, restore_fn = _reify_restore_metadata(is_default_windowing) + + pre_gbk_map = _add_pre_map_gkb_types(Map(reify_fn), is_default_windowing) + + ungrouped = pcoll | pre_gbk_map + + # TODO(https://github.com/apache/beam/issues/19785) Using global window as + # one of the standard window. This is to mitigate the Dataflow Java Runner + # Harness limitation to accept only standard coders. + ungrouped._windowing = Windowing( + window.GlobalWindows(), + triggerfn=Always(), + accumulation_mode=AccumulationMode.DISCARDING, + timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST) + result = ( + ungrouped + | GroupByKey() + | FlatMap(restore_fn).with_output_types(Any)) + result._windowing = windowing_saved + return result + @typehints.with_input_types(T) @typehints.with_output_types(T) diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 2443a049ddba..c8304255238c 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -18,6 +18,7 @@ """Unit tests for the transform.util classes.""" # pytype: skip-file +# pylint: disable=too-many-function-args import collections import importlib @@ -33,6 +34,8 @@ import pytest import pytz +from parameterized import param +from parameterized import parameterized import apache_beam as beam from apache_beam import GroupByKey @@ -75,6 +78,9 @@ from apache_beam.utils import timestamp from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP +from apache_beam.utils.windowed_value import PANE_INFO_UNKNOWN +from apache_beam.utils.windowed_value import PaneInfo +from apache_beam.utils.windowed_value import PaneInfoTiming from apache_beam.utils.windowed_value import WindowedValue warnings.filterwarnings( @@ -793,7 +799,10 @@ def test_reshuffle_windows_unchanged(self): with TestPipeline() as pipeline: data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)] expected_data = [ - TestWindowedValue(v, t - .001, [w]) + TestWindowedValue( + v, + t - .001, [w], + pane_info=PaneInfo(True, False, PaneInfoTiming.ON_TIME, 0, 0)) for (v, t, w) in [((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)), @@ -826,6 +835,7 @@ def test_reshuffle_window_fn_preserved(self): any_order = contains_in_any_order with TestPipeline() as pipeline: data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)] + expected_windows = [ TestWindowedValue(v, t, [w]) for (v, t, w) in [((1, 1), 1.0, IntervalWindow(1.0, 3.0)), ( @@ -838,7 +848,10 @@ def test_reshuffle_window_fn_preserved(self): IntervalWindow(4.0, 6.0))] ] expected_merged_windows = [ - TestWindowedValue(v, t - .001, [w]) + TestWindowedValue( + v, + t - .001, [w], + pane_info=PaneInfo(True, False, PaneInfoTiming.ON_TIME, 0, 0)) for (v, t, w) in [((1, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)), ( (2, any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)), ( @@ -942,6 +955,209 @@ def test_reshuffle_streaming_global_window_with_buckets(self): assert_that( after_reshuffle, equal_to(expected_data), label='after reshuffle') + @parameterized.expand([ + param(compat_version=None), + param(compat_version="2.64.0"), + ]) + def test_reshuffle_custom_window_preserves_metadata(self, compat_version): + """Tests that Reshuffle preserves pane info.""" + element_count = 12 + timestamp_value = timestamp.Timestamp(0) + l = [ + TimestampedValue(("key", i), timestamp_value) + for i in range(element_count) + ] + + expected_timestamp = GlobalWindow().max_timestamp() + expected = [ + TestWindowedValue( + ('key', [0, 1, 2]), + expected_timestamp, + [GlobalWindow()], + pane_info=PaneInfo( + is_first=True, + is_last=False, + timing=PaneInfoTiming.EARLY, # 0 + index=0, + nonspeculative_index=-1 + ) + ), + TestWindowedValue( + ('key', [3, 4, 5]), + expected_timestamp, + [GlobalWindow()], + pane_info=PaneInfo( + is_first=False, + is_last=False, + timing=PaneInfoTiming.EARLY, # 0 + index=1, + nonspeculative_index=-1 + ) + ), + TestWindowedValue( + ('key', [6, 7, 8]), + expected_timestamp, + [GlobalWindow()], + pane_info=PaneInfo( + is_first=False, + is_last=False, + timing=PaneInfoTiming.EARLY, # 0 + index=2, + nonspeculative_index=-1 + ) + ), + TestWindowedValue( + ('key', [9, 10, 11]), + expected_timestamp, + [GlobalWindow()], + pane_info=PaneInfo( + is_first=False, + is_last=False, + timing=PaneInfoTiming.EARLY, # 0 + index=3, + nonspeculative_index=-1 + ) + ) + ] if compat_version is None else ( + [ + TestWindowedValue( + ('key', [0, 1, 2]), + expected_timestamp, + [GlobalWindow()], + PANE_INFO_UNKNOWN + ), + TestWindowedValue( + ('key', [3, 4, 5]), + expected_timestamp, + [GlobalWindow()], + PANE_INFO_UNKNOWN + ), + TestWindowedValue( + ('key', [6, 7, 8]), + expected_timestamp, + [GlobalWindow()], + PANE_INFO_UNKNOWN + ), + TestWindowedValue( + ('key', [9, 10, 11]), + expected_timestamp, + [GlobalWindow()], + PANE_INFO_UNKNOWN + ) + ] + ) + + options = PipelineOptions(update_compatibility_version=compat_version) + options.view_as(StandardOptions).streaming = True + + with beam.Pipeline(options=options) as p: + stream_source = ( + TestStream().advance_watermark_to(0).advance_processing_time( + 100).add_elements(l[:element_count // 4]).advance_processing_time( + 100).advance_watermark_to(100).add_elements( + l[element_count // 4:2 * element_count // 4]). + advance_processing_time(100).advance_watermark_to(200).add_elements( + l[2 * element_count // 4:3 * element_count // + 4]).advance_processing_time( + 100).advance_watermark_to(300).add_elements( + l[3 * element_count // 4:]).advance_processing_time( + 100).advance_watermark_to_infinity()) + grouped = ( + p | stream_source + | "Rewindow" >> beam.WindowInto( + beam.window.GlobalWindows(), + trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)), + accumulation_mode=trigger.AccumulationMode.DISCARDING) + | beam.GroupByKey()) + + after_reshuffle = (grouped | 'Reshuffle' >> beam.Reshuffle()) + + assert_that( + after_reshuffle, + equal_to(expected), + label='CheckMetadataPreserved', + reify_windows=True) + + @parameterized.expand([ + param(compat_version=None), + param(compat_version="2.64.0"), + ]) + def test_reshuffle_default_window_preserves_metadata(self, compat_version): + """Tests that Reshuffle preserves timestamp, window, and pane info + metadata.""" + + no_firing = PaneInfo( + is_first=True, + is_last=True, + timing=PaneInfoTiming.UNKNOWN, + index=0, + nonspeculative_index=0) + + on_time_only = PaneInfo( + is_first=True, + is_last=True, + timing=PaneInfoTiming.ON_TIME, + index=0, + nonspeculative_index=0) + + late_firing = PaneInfo( + is_first=False, + is_last=False, + timing=PaneInfoTiming.LATE, + index=1, + nonspeculative_index=1) + + expected_preserved = [ + TestWindowedValue('a', MIN_TIMESTAMP, [GlobalWindow()], no_firing), + TestWindowedValue( + 'b', timestamp.Timestamp(0), [GlobalWindow()], on_time_only), + TestWindowedValue( + 'c', timestamp.Timestamp(33), [GlobalWindow()], late_firing), + TestWindowedValue( + 'd', GlobalWindow().max_timestamp(), [GlobalWindow()], no_firing) + ] + + expected_not_preserved = [ + TestWindowedValue( + 'a', MIN_TIMESTAMP, [GlobalWindow()], PANE_INFO_UNKNOWN), + TestWindowedValue( + 'b', timestamp.Timestamp(0), [GlobalWindow()], PANE_INFO_UNKNOWN), + TestWindowedValue( + 'c', timestamp.Timestamp(33), [GlobalWindow()], PANE_INFO_UNKNOWN), + TestWindowedValue( + 'd', + GlobalWindow().max_timestamp(), [GlobalWindow()], + PANE_INFO_UNKNOWN) + ] + + expected = ( + expected_preserved + if compat_version is None else expected_not_preserved) + + options = PipelineOptions(update_compatibility_version=compat_version) + with TestPipeline(options=options) as pipeline: + # Create windowed values with specific metadata + elements = [ + WindowedValue('a', MIN_TIMESTAMP, [GlobalWindow()], no_firing), + WindowedValue( + 'b', timestamp.Timestamp(0), [GlobalWindow()], on_time_only), + WindowedValue( + 'c', timestamp.Timestamp(33), [GlobalWindow()], late_firing), + WindowedValue( + 'd', GlobalWindow().max_timestamp(), [GlobalWindow()], no_firing) + ] + + after_reshuffle = ( + pipeline + | 'Create' >> beam.Create(elements) + | 'Reshuffle' >> beam.Reshuffle()) + + assert_that( + after_reshuffle, + equal_to(expected), + label='CheckMetadataPreserved', + reify_windows=True) + @pytest.mark.it_validatesrunner def test_reshuffle_preserves_timestamps(self): with TestPipeline() as pipeline: