diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py index 94e30bf5b9d5..e97ce4905358 100644 --- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py @@ -398,6 +398,9 @@ def test_callbacks_with_exception(self): def test_register_finalizations(self): raise unittest.SkipTest("BEAM-11021") + def test_custom_merging_window(self): + raise unittest.SkipTest("BEAM-11004") + # Inherits all other tests. diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py index bc691235f6b2..a08aa5fa8475 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py @@ -24,6 +24,8 @@ import collections import copy import itertools +import uuid +import weakref from typing import TYPE_CHECKING from typing import Any from typing import DefaultDict @@ -55,6 +57,7 @@ from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id from apache_beam.runners.portability.fn_api_runner.translations import unique_name from apache_beam.runners.worker import bundle_processor +from apache_beam.transforms import core from apache_beam.transforms import trigger from apache_beam.transforms import window from apache_beam.transforms.window import GlobalWindow @@ -69,7 +72,6 @@ from apache_beam.runners.portability.fn_api_runner.fn_runner import DataOutput from apache_beam.runners.portability.fn_api_runner.fn_runner import OutputTimers from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput - from apache_beam.transforms import core from apache_beam.transforms.window import BoundedWindow ENCODED_IMPULSE_VALUE = WindowedValueCoder( @@ -338,6 +340,222 @@ def 
class GenericMergingWindowFn(window.WindowFn):
  """Runner-side WindowFn whose merge() is carried out by an SDK worker.

  The instance wraps a WindowingStrategy proto.  When merge() is called it
  ships the candidate windows to a worker belonging to the strategy's
  environment, runs a small read -> MERGE_WINDOWS -> write bundle there, and
  replays the reported merges onto the caller's MergeContext.

  Instances are not serialized: payload() carries only an opaque handle id and
  from_runner_api_parameter() resolves that id through the class-level
  _HANDLES registry.
  """

  URN = 'internal-generic-merging'

  # Transform ids of the data-plane read and write in the merge bundle.
  TO_SDK_TRANSFORM = 'read'
  FROM_SDK_TRANSFORM = 'write'

  # Registry of live instances, keyed by handle id (see payload()).
  _HANDLES = {}  # type: Dict[str, GenericMergingWindowFn]

  def __init__(self, execution_context, windowing_strategy_proto):
    # type: (FnApiRunnerExecutionContext, beam_runner_api_pb2.WindowingStrategy) -> None

    """Registers a new handle for the given windowing strategy.

    Args:
      execution_context: the owning execution context; only a weak reference
        to it is kept.
      windowing_strategy_proto: the strategy whose window_fn, window_coder_id
        and environment_id are used to perform merging.
    """
    self._worker_handler = None  # type: Optional[worker_handlers.WorkerHandler]
    self._handle_id = handle_id = uuid.uuid4().hex
    self._HANDLES[handle_id] = self
    # ExecutionContexts are expensive, we don't want to keep them in the
    # static dictionary forever. Instead we hold a weakref and pop self
    # out of the dict once this context goes away.
    self._execution_context_ref_obj = weakref.ref(
        execution_context, lambda _: self._HANDLES.pop(handle_id, None))
    self._windowing_strategy_proto = windowing_strategy_proto
    # Source of the numeric suffix handed out by uid().
    self._counter = 0
    # Lazily created in make_process_bundle_descriptor()
    self._process_bundle_descriptor = None
    self._bundle_processor_id = None  # type: Optional[str]
    self.windowed_input_coder_impl = None  # type: Optional[CoderImpl]
    self.windowed_output_coder_impl = None  # type: Optional[CoderImpl]

  def _execution_context_ref(self):
    # type: () -> FnApiRunnerExecutionContext

    """Returns the execution context, asserting that it is still alive."""
    result = self._execution_context_ref_obj()
    assert result is not None
    return result

  def payload(self):
    # type: () -> bytes

    """Returns the UTF-8 encoded handle id identifying this instance."""
    return self._handle_id.encode('utf-8')

  @staticmethod
  @window.urns.RunnerApiFn.register_urn(URN, bytes)
  def from_runner_api_parameter(handle_id, unused_context):
    # type: (bytes, Any) -> GenericMergingWindowFn

    """Looks up the live instance registered under handle_id.

    Raises:
      KeyError: if no instance is registered under that id.
    """
    return GenericMergingWindowFn._HANDLES[handle_id.decode('utf-8')]

  def assign(self, assign_context):
    # type: (window.WindowFn.AssignContext) -> Iterable[window.BoundedWindow]

    """Not supported; this class only implements merging."""
    raise NotImplementedError()

  def merge(self, merge_context):
    # type: (window.WindowFn.MergeContext) -> None

    """Merges merge_context.windows by running a bundle on the SDK worker.

    The windows are written to the worker as a single globally-windowed
    (b'', windows) element.  For every (merge_result, originals) pair in the
    worker's output, merge_context.merge(originals, merge_result) is invoked.

    Raises:
      RuntimeError: if anything other than data elements comes back on the
        data channel, or if the bundle result carries an error.
    """
    worker_handler = self.worker_handle()

    # Both impls are populated by make_process_bundle_descriptor(), which
    # worker_handle() runs on first use.
    assert self.windowed_input_coder_impl is not None
    assert self.windowed_output_coder_impl is not None
    process_bundle_id = self.uid('process')

    encoded_windows = self.windowed_input_coder_impl.encode_nested(
        window.GlobalWindows.windowed_value((b'', merge_context.windows)))
    to_worker = worker_handler.data_conn.output_stream(
        process_bundle_id, self.TO_SDK_TRANSFORM)
    to_worker.write(encoded_windows)
    to_worker.close()

    process_bundle_req = beam_fn_api_pb2.InstructionRequest(
        instruction_id=process_bundle_id,
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_id=self._bundle_processor_id))
    result_future = worker_handler.control_conn.push(process_bundle_req)

    def bundle_failed():
      # type: () -> bool
      # Lets the read loop below stop once the bundle has finished in error.
      return bool(result_future.is_done() and result_future.get().error)

    for output in worker_handler.data_conn.input_elements(
        process_bundle_id, [self.FROM_SDK_TRANSFORM],
        abort_callback=bundle_failed):
      if not isinstance(output, beam_fn_api_pb2.Elements.Data):
        raise RuntimeError("Unexpected data: %s" % output)
      windowed_result = self.windowed_output_coder_impl.decode_nested(
          output.data)
      # The decoded value is (nonce, (unmerged_windows, merges)); only the
      # merges are replayed.
      for merge_result, originals in windowed_result.value[1][1]:
        merge_context.merge(originals, merge_result)

    result = result_future.get()
    if result.error:
      raise RuntimeError(result.error)
    # The result was "returned" via the merge callbacks on merge_context above.

  def get_window_coder(self):
    # type: () -> coders.Coder

    """Returns the coder registered under the strategy's window_coder_id."""
    return self._execution_context_ref().pipeline_context.coders[
        self._windowing_strategy_proto.window_coder_id]

  def worker_handle(self):
    # type: () -> worker_handlers.WorkerHandler

    """Returns the worker handler used for merging, creating it on first use.

    On first use a single worker handler is requested for the strategy's
    environment and the merge bundle descriptor is built and registered.
    """
    if self._worker_handler is None:
      worker_handler_manager = self._execution_context_ref(
      ).worker_handler_manager
      worker_handler = worker_handler_manager.get_worker_handlers(
          self._windowing_strategy_proto.environment_id, 1)[0]
      process_bundle_descriptor = self.make_process_bundle_descriptor(
          worker_handler.data_api_service_descriptor(),
          worker_handler.state_api_service_descriptor())
      worker_handler_manager.register_process_bundle_descriptor(
          process_bundle_descriptor)
      # Cache the handler only once the descriptor is registered; otherwise a
      # failure above would leave a handler cached with no usable descriptor
      # and every later merge() would be sent with a bogus descriptor id.
      self._worker_handler = worker_handler
    return self._worker_handler

  def make_process_bundle_descriptor(
      self, data_api_service_descriptor, state_api_service_descriptor):
    # type: (Optional[endpoints_pb2.ApiServiceDescriptor], Optional[endpoints_pb2.ApiServiceDescriptor]) -> beam_fn_api_pb2.ProcessBundleDescriptor

    """Creates a ProcessBundleDescriptor for invoking the WindowFn's
    merge operation.

    As side effects this registers the coders it creates with the pipeline
    context and sets windowed_input_coder_impl, windowed_output_coder_impl
    and _bundle_processor_id on self.
    """
    def make_channel_payload(coder_id):
      # type: (str) -> bytes
      # Serialized RemoteGrpcPort naming the coder and, if known, the data
      # endpoint url.
      data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
      if data_api_service_descriptor:
        data_spec.api_service_descriptor.url = (data_api_service_descriptor.url)
      return data_spec.SerializeToString()

    pipeline_context = self._execution_context_ref().pipeline_context
    global_windowing_strategy_id = self.uid('global_windowing_strategy')
    global_windowing_strategy_proto = core.Windowing(
        window.GlobalWindows()).to_runner_api(pipeline_context)
    coders = dict(pipeline_context.coders.get_id_to_proto_map())

    def make_coder(urn, *components):
      # type: (str, *str) -> str
      # Creates a coder proto under a fresh id and records it both in the
      # descriptor's coder map and in the pipeline context.
      coder_proto = beam_runner_api_pb2.Coder(
          spec=beam_runner_api_pb2.FunctionSpec(urn=urn),
          component_coder_ids=components)
      coder_id = self.uid('coder')
      coders[coder_id] = coder_proto
      pipeline_context.coders.put_proto(coder_id, coder_proto)
      return coder_id

    kv_urn = common_urns.coders.KV.urn
    iterable_urn = common_urns.coders.ITERABLE.urn
    windowed_value_urn = common_urns.coders.WINDOWED_VALUE.urn

    bytes_coder_id = make_coder(common_urns.coders.BYTES.urn)
    window_coder_id = self._windowing_strategy_proto.window_coder_id
    global_window_coder_id = make_coder(common_urns.coders.GLOBAL_WINDOW.urn)
    iter_window_coder_id = make_coder(iterable_urn, window_coder_id)

    # Input element: KV<bytes, Iterable<Window>>.
    input_coder_id = make_coder(kv_urn, bytes_coder_id, iter_window_coder_id)

    # Output element:
    #   KV<bytes, KV<Iterable<Window>, Iterable<KV<Window, Iterable<Window>>>>>
    merge_entry_coder_id = make_coder(
        kv_urn, window_coder_id, iter_window_coder_id)
    merges_coder_id = make_coder(iterable_urn, merge_entry_coder_id)
    merge_output_coder_id = make_coder(
        kv_urn, iter_window_coder_id, merges_coder_id)
    output_coder_id = make_coder(
        kv_urn, bytes_coder_id, merge_output_coder_id)

    windowed_input_coder_id = make_coder(
        windowed_value_urn, input_coder_id, global_window_coder_id)
    windowed_output_coder_id = make_coder(
        windowed_value_urn, output_coder_id, global_window_coder_id)

    self.windowed_input_coder_impl = pipeline_context.coders[
        windowed_input_coder_id].get_impl()
    self.windowed_output_coder_impl = pipeline_context.coders[
        windowed_output_coder_id].get_impl()

    self._bundle_processor_id = self.uid('merge_windows')

    read_transform = beam_runner_api_pb2.PTransform(
        unique_name='MergeWindows/Read',
        spec=beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.DATA_INPUT_URN,
            payload=make_channel_payload(windowed_input_coder_id)),
        outputs={'input': 'input'})
    merge_transform = beam_runner_api_pb2.PTransform(
        unique_name='MergeWindows/Merge',
        spec=beam_runner_api_pb2.FunctionSpec(
            urn=common_urns.primitives.MERGE_WINDOWS.urn,
            payload=self._windowing_strategy_proto.window_fn.
            SerializeToString()),
        inputs={'input': 'input'},
        outputs={'output': 'output'})
    write_transform = beam_runner_api_pb2.PTransform(
        unique_name='MergeWindows/Write',
        spec=beam_runner_api_pb2.FunctionSpec(
            urn=bundle_processor.DATA_OUTPUT_URN,
            payload=make_channel_payload(windowed_output_coder_id)),
        inputs={'output': 'output'})

    return beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._bundle_processor_id,
        transforms={
            self.TO_SDK_TRANSFORM: read_transform,
            'Merge': merge_transform,
            self.FROM_SDK_TRANSFORM: write_transform,
        },
        pcollections={
            'input': beam_runner_api_pb2.PCollection(
                unique_name='input',
                windowing_strategy_id=global_windowing_strategy_id,
                coder_id=input_coder_id),
            'output': beam_runner_api_pb2.PCollection(
                unique_name='output',
                windowing_strategy_id=global_windowing_strategy_id,
                coder_id=output_coder_id),
        },
        coders=coders,
        windowing_strategies={
            global_windowing_strategy_id: global_windowing_strategy_proto,
        },
        environments=dict(
            self._execution_context_ref().pipeline_components.environments.
            items()),
        state_api_service_descriptor=state_api_service_descriptor,
        timer_api_service_descriptor=data_api_service_descriptor)

  def uid(self, name=''):
    # type: (str) -> str

    """Returns a fresh id of the form '<handle id>_<name>_<counter>'."""
    self._counter += 1
    return '%s_%s_%s' % (self._handle_id, name, self._counter)
def test_custom_merging_window(self):
  """Runs a GroupByKey under CustomMergingWindowFn and checks handle cleanup.

  Each value v is timestamped at v.  The expected grouping has 2, 100 and 102
  together and 1 and 101 on their own.  After the pipeline has finished and a
  garbage collection has run, GenericMergingWindowFn._HANDLES must be empty.
  """
  expected_groups = [('k', [1]), ('k', [101]), ('k', [2, 100, 102])]
  with self.create_pipeline() as p:
    timestamped = (
        p
        | beam.Create([1, 2, 100, 101, 102])
        | beam.Map(lambda t: window.TimestampedValue(('k', t), t)))
    grouped = (
        timestamped
        | beam.WindowInto(CustomMergingWindowFn())
        | beam.GroupByKey())
    res = grouped | beam.Map(lambda kv: (kv[0], sorted(kv[1])))
    assert_that(res, equal_to(expected_groups))
  # Collect garbage before inspecting the handle registry.
  gc.collect()
  from apache_beam.runners.portability.fn_api_runner.execution import GenericMergingWindowFn
  self.assertEqual(GenericMergingWindowFn._HANDLES, {})


# TODO(robertwb): Why does pickling break when this is inlined?
class CustomMergingWindowFn(window.WindowFn):
  """Test WindowFn that merges every window with an even start.

  Each element is assigned the interval [timestamp, timestamp + 1).  On merge,
  all windows whose start is even are combined into one window running from
  the smallest such start to the largest such end; other windows are left
  untouched.
  """
  def assign(self, assign_context):
    start = assign_context.timestamp
    end = assign_context.timestamp + 1
    return [window.IntervalWindow(start, end)]

  def merge(self, merge_context):
    even_windows = [
        candidate for candidate in merge_context.windows
        if candidate.start % 2 == 0
    ]
    if not even_windows:
      # Nothing to merge.
      return
    span_start = min(w.start for w in even_windows)
    span_end = max(w.end for w in even_windows)
    merge_context.merge(
        even_windows, window.IntervalWindow(span_start, span_end))

  def get_window_coder(self):
    return coders.IntervalWindowCoder()
@BeamTransformFactory.register_urn(
    common_urns.primitives.MERGE_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec)
def create_merge_windows(
    factory,  # type: BeamTransformFactory
    transform_id,  # type: str
    transform_proto,  # type: beam_runner_api_pb2.PTransform
    mapping_fn_spec,  # type: beam_runner_api_pb2.FunctionSpec
    consumers  # type: Dict[str, List[operations.Operation]]
):
  """Builds the operation for the MERGE_WINDOWS primitive.

  The WindowFn is unpickled from mapping_fn_spec.payload, which must use the
  PICKLED_WINDOWFN urn.  Each input element is a (nonce, windows) pair; the
  output element is (nonce, (unmerged, merges)) where `unmerged` is the set of
  input windows the WindowFn did not merge and `merges` are
  (merge_result, original_windows) items.

  Returns:
    The result of _create_simple_pardo_operation for the MergeWindows DoFn.
  """
  assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOWFN
  # NOTE(review): this unpickles the payload of the transform spec; it is
  # presumably produced by the pipeline author / runner and trusted -- confirm.
  window_fn = pickler.loads(mapping_fn_spec.payload)

  class MergeWindows(beam.DoFn):
    def process(self, element):
      nonce, windows = element

      # Input windows that have not (yet) been merged into anything.
      original_windows = set(windows)  # type: Set[window.BoundedWindow]
      # merge result -> the input windows it was ultimately built from.
      merged_windows = collections.defaultdict(
          set
      )  # type: MutableMapping[window.BoundedWindow, Set[window.BoundedWindow]]

      class RecordingMergeContext(window.WindowFn.MergeContext):
        def merge(
            self,
            to_be_merged,  # type: Iterable[window.BoundedWindow]
            merge_result,  # type: window.BoundedWindow
        ):
          originals = merged_windows[merge_result]
          for merged_from in to_be_merged:
            if merged_from in original_windows:
              originals.add(merged_from)
              original_windows.remove(merged_from)
            elif merged_from != merge_result:
              # An earlier merge result is being merged again: fold in the
              # input windows it stood for.
              originals.update(merged_windows.pop(merged_from))
            # Otherwise an earlier result is being extended in place under
            # the same window.  Its input windows are already in `originals`;
            # popping it here would drop the entry being built and lose the
            # merge from the output.

      window_fn.merge(RecordingMergeContext(windows))
      yield nonce, (original_windows, merged_windows.items())

  return _create_simple_pardo_operation(
      factory, transform_id, transform_proto, consumers, MergeWindows())