@@ -38,7 +38,7 @@ class FlinkStreamingImpulseSource(PTransform):
   def expand(self, pbegin):
     assert isinstance(pbegin, pvalue.PBegin), (
         'Input to transform must be a PBegin but found %s' % pbegin)
-    return pvalue.PCollection(pbegin.pipeline)
+    return pvalue.PCollection(pbegin.pipeline, is_bounded=False)
 
   def get_windowing(self, inputs):
     return Windowing(GlobalWindows())
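This is the pattern the change applies everywhere: a root transform that emits elements indefinitely constructs its output PCollection with is_bounded=False. A minimal sketch of such a transform, using the same imports as the modified sources (the class name PeriodicImpulse is hypothetical):

import apache_beam as beam
from apache_beam import pvalue
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.window import GlobalWindows


class PeriodicImpulse(beam.PTransform):
  """Hypothetical root transform that produces elements forever."""

  def expand(self, pbegin):
    assert isinstance(pbegin, pvalue.PBegin)
    # Mark the output unbounded so boundedness tracking propagates downstream.
    return pvalue.PCollection(pbegin.pipeline, is_bounded=False)

  def get_windowing(self, inputs):
    return Windowing(GlobalWindows())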
sdks/python/apache_beam/io/iobase.py (2 additions & 1 deletion)
@@ -880,7 +880,8 @@ def split_source(unused_impulse):
               split.start_position, split.stop_position))))
     else:
       # Treat Read itself as a primitive.
-      return pvalue.PCollection(self.pipeline)
+      return pvalue.PCollection(self.pipeline,
+                                is_bounded=self.source.is_bounded())
 
   def get_windowing(self, unused_inputs):
     return core.Windowing(window.GlobalWindows())
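For Read, boundedness now comes from the source itself via source.is_bounded(). A sketch of a source that declares itself unbounded, modeled on the FakeUnboundedSource test fixture below (the NativeSource import path is an assumption; the stubbed reader() makes this usable for graph construction only):

from apache_beam.runners.dataflow.native_io.iobase import NativeSource


class MyUnboundedSource(NativeSource):
  """Hypothetical source; declares itself unbounded, never actually read."""

  def reader(self):
    return None  # stub: fine at construction time, unusable at runtime

  def is_bounded(self):
    return False  # Read(MyUnboundedSource()) output reports is_bounded == False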
sdks/python/apache_beam/pipeline_test.py (60 additions & 0 deletions)
@@ -90,6 +90,16 @@ def reader(self):
     return FakeSource._Reader(self._vals)
 
 
+class FakeUnboundedSource(NativeSource):
+  """Fake unbounded source. Does not work at runtime"""
+
+  def reader(self):
+    return None
+
+  def is_bounded(self):
+    return False
+
+
 class DoubleParDo(beam.PTransform):
   def expand(self, input):
     return input | 'Inner' >> beam.Map(lambda a: a * 2)
@@ -461,6 +471,56 @@ def process(self, element, counter=DoFn.StateParam(BYTES_STATE)):
     p.run()
     self.assertEqual(pcoll.element_type, str)
 
+  def test_track_pcoll_unbounded(self):
+    pipeline = TestPipeline()
+    pcoll1 = pipeline | 'read' >> Read(FakeUnboundedSource())
+    pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
+    pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
+    self.assertIs(pcoll1.is_bounded, False)
+    self.assertIs(pcoll2.is_bounded, False)
+    self.assertIs(pcoll3.is_bounded, False)
+
+  def test_track_pcoll_bounded(self):
+    pipeline = TestPipeline()
+    pcoll1 = pipeline | 'label1' >> Create([1, 2, 3])
+    pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
+    pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
+    self.assertIs(pcoll1.is_bounded, True)
+    self.assertIs(pcoll2.is_bounded, True)
+    self.assertIs(pcoll3.is_bounded, True)
+
+  def test_track_pcoll_bounded_flatten(self):
+    pipeline = TestPipeline()
+    pcoll1_a = pipeline | 'label_a' >> Create([1, 2, 3])
+    pcoll2_a = pcoll1_a | 'do_a' >> FlatMap(lambda x: [x + 1])
+
+    pcoll1_b = pipeline | 'label_b' >> Create([1, 2, 3])
+    pcoll2_b = pcoll1_b | 'do_b' >> FlatMap(lambda x: [x + 1])
+
+    merged = (pcoll2_a, pcoll2_b) | beam.Flatten()
+
+    self.assertIs(pcoll1_a.is_bounded, True)
+    self.assertIs(pcoll2_a.is_bounded, True)
+    self.assertIs(pcoll1_b.is_bounded, True)
+    self.assertIs(pcoll2_b.is_bounded, True)
+    self.assertIs(merged.is_bounded, True)
+
+  def test_track_pcoll_unbounded_flatten(self):
+    pipeline = TestPipeline()
+    pcoll1_bounded = pipeline | 'label1' >> Create([1, 2, 3])
+    pcoll2_bounded = pcoll1_bounded | 'do1' >> FlatMap(lambda x: [x + 1])
+
+    pcoll1_unbounded = pipeline | 'read' >> Read(FakeUnboundedSource())
+    pcoll2_unbounded = pcoll1_unbounded | 'do2' >> FlatMap(lambda x: [x + 1])
+
+    merged = (pcoll2_bounded, pcoll2_unbounded) | beam.Flatten()
+
+    self.assertIs(pcoll1_bounded.is_bounded, True)
+    self.assertIs(pcoll2_bounded.is_bounded, True)
+    self.assertIs(pcoll1_unbounded.is_bounded, False)
+    self.assertIs(pcoll2_unbounded.is_bounded, False)
+    self.assertIs(merged.is_bounded, False)
+
+
 class DoFnTest(unittest.TestCase):
 
sdks/python/apache_beam/pvalue.py (16 additions & 3 deletions)
@@ -63,7 +63,8 @@ class PValue(object):
   (3) Has a value which is meaningful if the transform was executed.
   """
 
-  def __init__(self, pipeline, tag=None, element_type=None, windowing=None):
+  def __init__(self, pipeline, tag=None, element_type=None, windowing=None,
+               is_bounded=True):
     """Initializes a PValue with all arguments hidden behind keyword arguments.
 
     Args:
@@ -78,6 +79,7 @@ def __init__(self, pipeline, tag=None, element_type=None, windowing=None):
     # generating this PValue. The field gets initialized when a transform
     # gets applied.
     self.producer = None
+    self.is_bounded = is_bounded
     if windowing:
       self._windowing = windowing
 
@@ -142,11 +144,21 @@ def __reduce_ex__(self, unused_version):
     # of a closure).
     return _InvalidUnpickledPCollection, ()
 
+  @staticmethod
+  def from_(pcoll):
+    """Create a PCollection, using another PCollection as a starting point.
+
+    Transfers relevant attributes.
+    """
+    return PCollection(pcoll.pipeline, is_bounded=pcoll.is_bounded)
+
   def to_runner_api(self, context):
     return beam_runner_api_pb2.PCollection(
         unique_name=self._unique_name(),
         coder_id=context.coder_id_from_element_type(self.element_type),
-        is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED,
+        is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED
+        if self.is_bounded
+        else beam_runner_api_pb2.IsBounded.UNBOUNDED,
         windowing_strategy_id=context.windowing_strategies.get_id(
             self.windowing))
 
@@ -165,7 +177,8 @@ def from_runner_api(proto, context):
         None,
         element_type=context.element_type_from_coder_id(proto.coder_id),
         windowing=context.windowing_strategies.get_by_id(
-            proto.windowing_strategy_id))
+            proto.windowing_strategy_id),
+        is_bounded=proto.is_bounded == beam_runner_api_pb2.IsBounded.BOUNDED)
 
 
 class _InvalidUnpickledPCollection(object):
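Together, from_ and the runner-API mapping keep the flag consistent across in-process construction and proto round-trips. A minimal standalone sketch of that mapping, using the same beam_runner_api_pb2 names as the code above (the helper names are illustrative):

from apache_beam.portability.api import beam_runner_api_pb2


def bounded_to_proto(is_bounded):
  # Mirrors PCollection.to_runner_api's conditional.
  return (beam_runner_api_pb2.IsBounded.BOUNDED if is_bounded
          else beam_runner_api_pb2.IsBounded.UNBOUNDED)


def bounded_from_proto(value):
  # Mirrors PCollection.from_runner_api's comparison.
  return value == beam_runner_api_pb2.IsBounded.BOUNDED


assert bounded_from_proto(bounded_to_proto(True)) is True
assert bounded_from_proto(bounded_to_proto(False)) is False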
sdks/python/apache_beam/runners/dataflow/dataflow_runner.py (6 additions & 5 deletions)
@@ -310,7 +310,8 @@ def visit_transform(self, transform_node):
         new_side_input.pvalue = beam.pvalue.PCollection(
             pipeline,
             element_type=typehints.KV[
-                bytes, side_input.pvalue.element_type])
+                bytes, side_input.pvalue.element_type],
+            is_bounded=side_input.pvalue.is_bounded)
         parent = transform_node.parent or pipeline._root_transform()
         map_to_void_key = beam.pipeline.AppliedPTransform(
             pipeline,
@@ -712,7 +713,7 @@ def apply_GroupByKey(self, transform, pcoll, options):
     coders.registry.verify_deterministic(
         coder.key_coder(), 'GroupByKey operation "%s"' % transform.label)
 
-    return pvalue.PCollection(pcoll.pipeline)
+    return pvalue.PCollection.from_(pcoll)
 
   def run_GroupByKey(self, transform_node, options):
     input_tag = transform_node.inputs[0].tag
@@ -894,7 +895,7 @@ def _pardo_fn_data(transform_node, get_label):
         transform_node.inputs[0].windowing)
 
   def apply_CombineValues(self, transform, pcoll, options):
-    return pvalue.PCollection(pcoll.pipeline)
+    return pvalue.PCollection.from_(pcoll)
 
   def run_CombineValues(self, transform_node, options):
     transform = transform_node.transform
@@ -947,7 +948,7 @@ def run_CombineValues(self, transform_node, options):
   def apply_Read(self, transform, pbegin, options):
     if hasattr(transform.source, 'format'):
       # Consider native Read to be a primitive for dataflow.
-      return beam.pvalue.PCollection(pbegin.pipeline)
+      return beam.pvalue.PCollection.from_(pbegin)
     else:
       debug_options = options.view_as(DebugOptions)
       if (
@@ -958,7 +959,7 @@
         return self.apply_PTransform(transform, pbegin, options)
       else:
         # Custom Read is also a primitive for non-FnAPI on dataflow.
-        return beam.pvalue.PCollection(pbegin.pipeline)
+        return beam.pvalue.PCollection.from_(pbegin)
 
   def run_Read(self, transform_node, options):
     transform = transform_node.transform
@@ -61,7 +61,7 @@ class Impulse(PTransform):
   def expand(self, pbegin):
     assert isinstance(pbegin, pvalue.PBegin), (
         'Input to Impulse transform must be a PBegin but found %s' % pbegin)
-    return pvalue.PCollection(pbegin.pipeline)
+    return pvalue.PCollection(pbegin.pipeline, is_bounded=False)
 
   def get_windowing(self, inputs):
     return Windowing(GlobalWindows())
sdks/python/apache_beam/runners/direct/direct_runner.py (1 addition & 1 deletion)
@@ -251,7 +251,7 @@ def get_windowing(self, inputs):
 
   def expand(self, pvalue):
     # This is handled as a native transform.
-    return PCollection(self.pipeline)
+    return PCollection(self.pipeline, is_bounded=self._source.is_bounded())
 
 
 class _DirectWriteToPubSubFn(DoFn):
@@ -85,7 +85,7 @@ def __init__(self, process_keyed_elements_transform):
     self.sdf = self._process_keyed_elements_transform.sdf
 
   def expand(self, pcoll):
-    return pvalue.PCollection(pcoll.pipeline)
+    return pvalue.PCollection.from_(pcoll)
 
   def new_process_fn(self, sdf):
     return ProcessFn(
sdks/python/apache_beam/runners/sdf_common.py (1 addition & 1 deletion)
@@ -167,4 +167,4 @@ def __init__(
     self.ptransform_side_inputs = ptransform_side_inputs
 
   def expand(self, pcoll):
-    return pvalue.PCollection(pcoll.pipeline)
+    return pvalue.PCollection.from_(pcoll)
sdks/python/apache_beam/testing/test_stream.py (1 addition & 1 deletion)
@@ -135,7 +135,7 @@ def get_windowing(self, unused_inputs):
   def expand(self, pbegin):
     assert isinstance(pbegin, pvalue.PBegin)
     self.pipeline = pbegin.pipeline
-    return pvalue.PCollection(self.pipeline)
+    return pvalue.PCollection(self.pipeline, is_bounded=False)
 
   def _infer_output_coder(self, input_type=None, input_coder=None):
     return self.coder
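TestStream models a streaming source, so its output now reports itself unbounded at construction time. A quick sketch of the observable effect (assumes the default TestStream constructor and that add_elements returns the transform for chaining):

from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream

p = TestPipeline()
events = p | TestStream().add_elements(['a', 'b'])
# Construction-time metadata only; the pipeline is never run here.
assert events.is_bounded is False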
sdks/python/apache_beam/transforms/core.py (5 additions & 4 deletions)
@@ -1125,7 +1125,7 @@ def expand(self, pcoll):
         'key types. Consider adding an input type hint for this transform.',
         key_coder, self)
 
-    return pvalue.PCollection(pcoll.pipeline)
+    return pvalue.PCollection.from_(pcoll)
 
   def with_outputs(self, *tags, **main_kw):
     """Returns a tagged tuple allowing access to the outputs of a
@@ -2041,7 +2041,7 @@ def infer_output_type(self, input_type):
 
   def expand(self, pcoll):
     self._check_pcollection(pcoll)
-    return pvalue.PCollection(pcoll.pipeline)
+    return pvalue.PCollection.from_(pcoll)
 
 
 @typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]])
@@ -2055,7 +2055,7 @@ def __init__(self, windowing):
 
   def expand(self, pcoll):
     self._check_pcollection(pcoll)
-    return pvalue.PCollection(pcoll.pipeline)
+    return pvalue.PCollection.from_(pcoll)
 
 
 class _GroupAlsoByWindowDoFn(DoFn):
@@ -2338,7 +2338,8 @@ def _extract_input_pvalues(self, pvalueish):
   def expand(self, pcolls):
     for pcoll in pcolls:
      self._check_pcollection(pcoll)
-    result = pvalue.PCollection(self.pipeline)
+    is_bounded = all(pcoll.is_bounded for pcoll in pcolls)
+    result = pvalue.PCollection(self.pipeline, is_bounded=is_bounded)
     result.element_type = typehints.Union[
         tuple(pcoll.element_type for pcoll in pcolls)]
     return result
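Flatten's rule: the merged output is bounded only if every input is, via the all(...) reduction above. A small construction-time sketch of the behavior (the pipeline is never run):

import apache_beam as beam

p = beam.Pipeline()
a = p | 'a' >> beam.Create([1, 2, 3])   # bounded
b = p | 'b' >> beam.Create([4, 5])      # bounded

merged = (a, b) | beam.Flatten()
assert merged.is_bounded  # all inputs bounded, so the result is bounded
# If any input were unbounded (e.g. a streaming source), the same
# all(...) reduction would make merged.is_bounded False.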