diff --git a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py
index d4a0dfc81fc1..1edf743408a0 100644
--- a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py
+++ b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py
@@ -38,7 +38,7 @@ class FlinkStreamingImpulseSource(PTransform):
   def expand(self, pbegin):
     assert isinstance(pbegin, pvalue.PBegin), (
         'Input to transform must be a PBegin but found %s' % pbegin)
-    return pvalue.PCollection(pbegin.pipeline)
+    return pvalue.PCollection(pbegin.pipeline, is_bounded=False)
 
   def get_windowing(self, inputs):
     return Windowing(GlobalWindows())
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 6763c57efa80..605c1bf0984d 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -880,7 +880,8 @@ def split_source(unused_impulse):
                   split.start_position, split.stop_position))))
     else:
       # Treat Read itself as a primitive.
-      return pvalue.PCollection(self.pipeline)
+      return pvalue.PCollection(self.pipeline,
+                                is_bounded=self.source.is_bounded())
 
   def get_windowing(self, unused_inputs):
     return core.Windowing(window.GlobalWindows())
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index 2b62745afefa..d1d9d0dfd791 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -90,6 +90,16 @@ def reader(self):
     return FakeSource._Reader(self._vals)
 
 
+class FakeUnboundedSource(NativeSource):
+  """Fake unbounded source. Does not work at runtime."""
+
+  def reader(self):
+    return None
+
+  def is_bounded(self):
+    return False
+
+
 class DoubleParDo(beam.PTransform):
   def expand(self, input):
     return input | 'Inner' >> beam.Map(lambda a: a * 2)
@@ -461,6 +471,56 @@ def process(self, element, counter=DoFn.StateParam(BYTES_STATE)):
       p.run()
       self.assertEqual(pcoll.element_type, str)
 
+  def test_track_pcoll_unbounded(self):
+    pipeline = TestPipeline()
+    pcoll1 = pipeline | 'read' >> Read(FakeUnboundedSource())
+    pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
+    pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
+    self.assertIs(pcoll1.is_bounded, False)
+    self.assertIs(pcoll2.is_bounded, False)
+    self.assertIs(pcoll3.is_bounded, False)
+
+  def test_track_pcoll_bounded(self):
+    pipeline = TestPipeline()
+    pcoll1 = pipeline | 'label1' >> Create([1, 2, 3])
+    pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
+    pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
+    self.assertIs(pcoll1.is_bounded, True)
+    self.assertIs(pcoll2.is_bounded, True)
+    self.assertIs(pcoll3.is_bounded, True)
+
+  def test_track_pcoll_bounded_flatten(self):
+    pipeline = TestPipeline()
+    pcoll1_a = pipeline | 'label_a' >> Create([1, 2, 3])
+    pcoll2_a = pcoll1_a | 'do_a' >> FlatMap(lambda x: [x + 1])
+
+    pcoll1_b = pipeline | 'label_b' >> Create([1, 2, 3])
+    pcoll2_b = pcoll1_b | 'do_b' >> FlatMap(lambda x: [x + 1])
+
+    merged = (pcoll2_a, pcoll2_b) | beam.Flatten()
+
+    self.assertIs(pcoll1_a.is_bounded, True)
+    self.assertIs(pcoll2_a.is_bounded, True)
+    self.assertIs(pcoll1_b.is_bounded, True)
+    self.assertIs(pcoll2_b.is_bounded, True)
+    self.assertIs(merged.is_bounded, True)
+
+  def test_track_pcoll_unbounded_flatten(self):
+    pipeline = TestPipeline()
+    pcoll1_bounded = pipeline | 'label1' >> Create([1, 2, 3])
+    pcoll2_bounded = pcoll1_bounded | 'do1' >> FlatMap(lambda x: [x + 1])
+
+    pcoll1_unbounded = pipeline | 'read' >> Read(FakeUnboundedSource())
+    pcoll2_unbounded = pcoll1_unbounded | 'do2' >> FlatMap(lambda x: [x + 1])
+
+    merged = (pcoll2_bounded, pcoll2_unbounded) | beam.Flatten()
+
+    self.assertIs(pcoll1_bounded.is_bounded, True)
+    self.assertIs(pcoll2_bounded.is_bounded, True)
+    self.assertIs(pcoll1_unbounded.is_bounded, False)
+    self.assertIs(pcoll2_unbounded.is_bounded, False)
+    self.assertIs(merged.is_bounded, False)
+
 
 class DoFnTest(unittest.TestCase):
diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py
index b60d0cfc5746..7c9d869289a6 100644
--- a/sdks/python/apache_beam/pvalue.py
+++ b/sdks/python/apache_beam/pvalue.py
@@ -63,7 +63,8 @@ class PValue(object):
     (3) Has a value which is meaningful if the transform was executed.
   """
 
-  def __init__(self, pipeline, tag=None, element_type=None, windowing=None):
+  def __init__(self, pipeline, tag=None, element_type=None, windowing=None,
+               is_bounded=True):
     """Initializes a PValue with all arguments hidden behind keyword arguments.
 
     Args:
@@ -78,6 +79,7 @@ def __init__(self, pipeline, tag=None, element_type=None, windowing=None):
     # generating this PValue. The field gets initialized when a transform
     # gets applied.
     self.producer = None
+    self.is_bounded = is_bounded
     if windowing:
       self._windowing = windowing
@@ -142,11 +144,21 @@ def __reduce_ex__(self, unused_version):
     # of a closure).
     return _InvalidUnpickledPCollection, ()
 
+  @staticmethod
+  def from_(pcoll):
+    """Create a PCollection, using another PCollection as a starting point.
+
+    Transfers relevant attributes.
+    """
+    return PCollection(pcoll.pipeline, is_bounded=pcoll.is_bounded)
+
   def to_runner_api(self, context):
     return beam_runner_api_pb2.PCollection(
         unique_name=self._unique_name(),
         coder_id=context.coder_id_from_element_type(self.element_type),
-        is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED,
+        is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED
+        if self.is_bounded
+        else beam_runner_api_pb2.IsBounded.UNBOUNDED,
         windowing_strategy_id=context.windowing_strategies.get_id(
             self.windowing))
 
@@ -165,7 +177,8 @@ def from_runner_api(proto, context):
         None,
         element_type=context.element_type_from_coder_id(proto.coder_id),
         windowing=context.windowing_strategies.get_by_id(
-            proto.windowing_strategy_id))
+            proto.windowing_strategy_id),
+        is_bounded=proto.is_bounded == beam_runner_api_pb2.IsBounded.BOUNDED)
 
 
 class _InvalidUnpickledPCollection(object):
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index ddc0ccdb2af7..383b9715e14a 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -310,7 +310,8 @@ def visit_transform(self, transform_node):
           new_side_input.pvalue = beam.pvalue.PCollection(
               pipeline,
               element_type=typehints.KV[
-                  bytes, side_input.pvalue.element_type])
+                  bytes, side_input.pvalue.element_type],
+              is_bounded=side_input.pvalue.is_bounded)
           parent = transform_node.parent or pipeline._root_transform()
           map_to_void_key = beam.pipeline.AppliedPTransform(
               pipeline,
@@ -712,7 +713,7 @@ def apply_GroupByKey(self, transform, pcoll, options):
       coders.registry.verify_deterministic(
           coder.key_coder(), 'GroupByKey operation "%s"' % transform.label)
 
-    return pvalue.PCollection(pcoll.pipeline)
+    return pvalue.PCollection.from_(pcoll)
 
   def run_GroupByKey(self, transform_node, options):
     input_tag = transform_node.inputs[0].tag
@@ -894,7 +895,7 @@ def _pardo_fn_data(transform_node, get_label):
transform_node.inputs[0].windowing) def apply_CombineValues(self, transform, pcoll, options): - return pvalue.PCollection(pcoll.pipeline) + return pvalue.PCollection.from_(pcoll) def run_CombineValues(self, transform_node, options): transform = transform_node.transform @@ -947,7 +948,7 @@ def run_CombineValues(self, transform_node, options): def apply_Read(self, transform, pbegin, options): if hasattr(transform.source, 'format'): # Consider native Read to be a primitive for dataflow. - return beam.pvalue.PCollection(pbegin.pipeline) + return beam.pvalue.PCollection.from_(pbegin) else: debug_options = options.view_as(DebugOptions) if ( @@ -958,7 +959,7 @@ def apply_Read(self, transform, pbegin, options): return self.apply_PTransform(transform, pbegin, options) else: # Custom Read is also a primitive for non-FnAPI on dataflow. - return beam.pvalue.PCollection(pbegin.pipeline) + return beam.pvalue.PCollection.from_(pbegin) def run_Read(self, transform_node, options): transform = transform_node.transform diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py b/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py index 980ad24c7962..481209ef90b9 100644 --- a/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py +++ b/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py @@ -61,7 +61,7 @@ class Impulse(PTransform): def expand(self, pbegin): assert isinstance(pbegin, pvalue.PBegin), ( 'Input to Impulse transform must be a PBegin but found %s' % pbegin) - return pvalue.PCollection(pbegin.pipeline) + return pvalue.PCollection(pbegin.pipeline, is_bounded=False) def get_windowing(self, inputs): return Windowing(GlobalWindows()) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 887b9f085477..7ae16a97497d 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -251,7 +251,7 @@ def get_windowing(self, inputs): def expand(self, pvalue): # This is handled as a native transform. 
- return PCollection(self.pipeline) + return PCollection(self.pipeline, is_bounded=self._source.is_bounded()) class _DirectWriteToPubSubFn(DoFn): diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py index 641863243d6a..307679018a5b 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py @@ -85,7 +85,7 @@ def __init__(self, process_keyed_elements_transform): self.sdf = self._process_keyed_elements_transform.sdf def expand(self, pcoll): - return pvalue.PCollection(pcoll.pipeline) + return pvalue.PCollection.from_(pcoll) def new_process_fn(self, sdf): return ProcessFn( diff --git a/sdks/python/apache_beam/runners/sdf_common.py b/sdks/python/apache_beam/runners/sdf_common.py index e0573289f3bb..072d3dc74492 100644 --- a/sdks/python/apache_beam/runners/sdf_common.py +++ b/sdks/python/apache_beam/runners/sdf_common.py @@ -167,4 +167,4 @@ def __init__( self.ptransform_side_inputs = ptransform_side_inputs def expand(self, pcoll): - return pvalue.PCollection(pcoll.pipeline) + return pvalue.PCollection.from_(pcoll) diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py index a5e9574331ec..02a860749c4b 100644 --- a/sdks/python/apache_beam/testing/test_stream.py +++ b/sdks/python/apache_beam/testing/test_stream.py @@ -135,7 +135,7 @@ def get_windowing(self, unused_inputs): def expand(self, pbegin): assert isinstance(pbegin, pvalue.PBegin) self.pipeline = pbegin.pipeline - return pvalue.PCollection(self.pipeline) + return pvalue.PCollection(self.pipeline, is_bounded=False) def _infer_output_coder(self, input_type=None, input_coder=None): return self.coder diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 807e8b5d3e6f..f08d0700c420 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1125,7 +1125,7 @@ def expand(self, pcoll): 'key types. Consider adding an input type hint for this transform.', key_coder, self) - return pvalue.PCollection(pcoll.pipeline) + return pvalue.PCollection.from_(pcoll) def with_outputs(self, *tags, **main_kw): """Returns a tagged tuple allowing access to the outputs of a @@ -2041,7 +2041,7 @@ def infer_output_type(self, input_type): def expand(self, pcoll): self._check_pcollection(pcoll) - return pvalue.PCollection(pcoll.pipeline) + return pvalue.PCollection.from_(pcoll) @typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]]) @@ -2055,7 +2055,7 @@ def __init__(self, windowing): def expand(self, pcoll): self._check_pcollection(pcoll) - return pvalue.PCollection(pcoll.pipeline) + return pvalue.PCollection.from_(pcoll) class _GroupAlsoByWindowDoFn(DoFn): @@ -2338,7 +2338,8 @@ def _extract_input_pvalues(self, pvalueish): def expand(self, pcolls): for pcoll in pcolls: self._check_pcollection(pcoll) - result = pvalue.PCollection(self.pipeline) + is_bounded = all(pcoll.is_bounded for pcoll in pcolls) + result = pvalue.PCollection(self.pipeline, is_bounded=is_bounded) result.element_type = typehints.Union[ tuple(pcoll.element_type for pcoll in pcolls)] return result
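
Taken together, the hunks above implement one small rule set for boundedness tracking: a PCollection is bounded by default; source transforms (Read, Impulse, TestStream) stamp their output with the boundedness of the underlying source; single-input transforms (GroupByKey, CombineValues, the SDF expansions) inherit their input's flag through the new PCollection.from_ helper; and Flatten's output is bounded only when every input is. The following is a minimal standalone sketch of those rules under the names used in the diff; it is illustrative code, not Beam itself.

class PCollection(object):
  """Stripped-down stand-in for apache_beam.pvalue.PCollection."""

  def __init__(self, pipeline, is_bounded=True):
    # Bounded by default, as in PValue.__init__ above; only sources ever
    # introduce an unbounded PCollection.
    self.pipeline = pipeline
    self.is_bounded = is_bounded

  @staticmethod
  def from_(pcoll):
    # Mirrors the new PCollection.from_: single-input transforms inherit
    # the input's boundedness.
    return PCollection(pcoll.pipeline, is_bounded=pcoll.is_bounded)


def expand_read(pipeline, source):
  # Mirrors Read.expand in iobase.py: the output takes its boundedness
  # from source.is_bounded() instead of the old hard-coded default.
  return PCollection(pipeline, is_bounded=source.is_bounded())


def expand_flatten(pcolls):
  # Mirrors Flatten.expand in transforms/core.py: the merged output is
  # bounded only if all inputs are bounded.
  is_bounded = all(pcoll.is_bounded for pcoll in pcolls)
  return PCollection(pcolls[0].pipeline, is_bounded=is_bounded)


class UnboundedSource(object):
  # Plays the role of FakeUnboundedSource in the tests above.
  def is_bounded(self):
    return False


bounded = PCollection('p')                       # e.g. output of Create
unbounded = expand_read('p', UnboundedSource())  # e.g. a streaming Read
assert PCollection.from_(unbounded).is_bounded is False
assert expand_flatten([bounded, unbounded]).is_bounded is False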
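
The pvalue.py hunks also keep the new flag in sync with the Runner API proto in both directions: to_runner_api maps the boolean onto the IsBounded enum, and from_runner_api maps it back. A sketch of that round trip, where the IsBounded class is a hypothetical stand-in for beam_runner_api_pb2.IsBounded (the real enum values come from the generated proto module, not this sketch):

class IsBounded(object):
  # Hypothetical stand-in for beam_runner_api_pb2.IsBounded; the actual
  # values are defined by the Runner API proto.
  UNBOUNDED = 1
  BOUNDED = 2


def to_runner_api_is_bounded(is_bounded):
  # Mirrors PCollection.to_runner_api: Python bool -> proto enum.
  return IsBounded.BOUNDED if is_bounded else IsBounded.UNBOUNDED


def from_runner_api_is_bounded(proto_value):
  # Mirrors PCollection.from_runner_api: proto enum -> Python bool.
  return proto_value == IsBounded.BOUNDED


for flag in (True, False):
  assert from_runner_api_is_bounded(to_runner_api_is_bounded(flag)) is flag

Note that any proto value other than BOUNDED, including a default or unspecified one, maps back to unbounded under the equality check used in from_runner_api.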