Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 47 additions & 11 deletions sdks/python/apache_beam/io/gcp/pubsub_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@

import hamcrest as hc

import apache_beam as beam
from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub
from apache_beam.io.gcp.pubsub import WriteStringsToPubSub
from apache_beam.io.gcp.pubsub import _PubSubPayloadSink
from apache_beam.io.gcp.pubsub import _PubSubPayloadSource
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.runners.direct.direct_runner import _get_transform_overrides
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
Expand All @@ -40,28 +43,51 @@


@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
class TestReadStringsFromPubSub(unittest.TestCase):
class TestReadStringsFromPubSubOverride(unittest.TestCase):
def test_expand_with_topic(self):
p = TestPipeline()
pcoll = p | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic',
None, 'a_label')
# Ensure that the output type is str
p.options.view_as(StandardOptions).streaming = True
pcoll = (p
| ReadStringsFromPubSub('projects/fakeprj/topics/a_topic',
None, 'a_label')
| beam.Map(lambda x: x))
# Ensure that the output type is str.
self.assertEqual(unicode, pcoll.element_type)

# Apply the necessary PTransformOverrides.
overrides = _get_transform_overrides(p.options)
p.replace_all(overrides)

# Note that the direct output of ReadStringsFromPubSub will be replaced
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests seem rather brittle. Is there a better way to test this transform application than grabbing the internal source and verifying a couple of properties on it? See https://beam.apache.org/documentation/pipelines/test-your-pipeline/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping on this comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've addressed this here: #4529 (comment)

Copy link
Contributor Author

@charlesccychen charlesccychen Feb 3, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

robertwb wrote:
These tests seem rather brittle. Is there a better way to test this transform application than grabbing the internal source and verifying a couple of properties on it. https://beam.apache.org/documentation/pipelines/test-your-pipeline/

Thanks, and I agree, but the issue is that there really isn't anything to test.

The test here isn't a test of the transform; rather, it was (and still is) testing the behavior of replacing the transform with the correct DirectRunner replacement. The transform itself is just a wrapper with runner-specific overrides.

I changed the test name to reflect the intended target of the test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant https://beam.apache.org/contribute/ptransform-style-guide/#testing-transform-construction-and-validation

If this is about the direct runner, we should put it into the direct runner tests. Ideally, we would create a mock/in-memory PubSub and make sure this works end-to-end (on any runner).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But, as mentioned, fixing these existing tests should not block this PR. Please file a JIRA.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# by a PTransformOverride, so we use a no-op Map.
read_transform = pcoll.producer.inputs[0].producer.transform

# Ensure that the properties passed through correctly
source = pcoll.producer.transform._source
source = read_transform._source
self.assertEqual('a_topic', source.topic_name)
self.assertEqual('a_label', source.id_label)

def test_expand_with_subscription(self):
p = TestPipeline()
pcoll = p | ReadStringsFromPubSub(
None, 'projects/fakeprj/subscriptions/a_subscription', 'a_label')
p.options.view_as(StandardOptions).streaming = True
pcoll = (p
| ReadStringsFromPubSub(
None, 'projects/fakeprj/subscriptions/a_subscription',
'a_label')
| beam.Map(lambda x: x))
# Ensure that the output type is str
self.assertEqual(unicode, pcoll.element_type)

# Apply the necessary PTransformOverrides.
overrides = _get_transform_overrides(p.options)
p.replace_all(overrides)

# Note that the direct output of ReadStringsFromPubSub will be replaced
# by a PTransformOverride, so we use a no-op Map.
read_transform = pcoll.producer.inputs[0].producer.transform

# Ensure that the properties passed through correctly
source = pcoll.producer.transform._source
source = read_transform._source
self.assertEqual('a_subscription', source.subscription_name)
self.assertEqual('a_label', source.id_label)

Expand All @@ -80,12 +106,22 @@ def test_expand_with_both_topic_and_subscription(self):
class TestWriteStringsToPubSub(unittest.TestCase):
def test_expand(self):
p = TestPipeline()
pdone = (p
p.options.view_as(StandardOptions).streaming = True
pcoll = (p
| ReadStringsFromPubSub('projects/fakeprj/topics/baz')
| WriteStringsToPubSub('projects/fakeprj/topics/a_topic'))
| WriteStringsToPubSub('projects/fakeprj/topics/a_topic')
| beam.Map(lambda x: x))

# Apply the necessary PTransformOverrides.
overrides = _get_transform_overrides(p.options)
p.replace_all(overrides)

# Note that the direct output of WriteStringsToPubSub will be replaced
# by a PTransformOverride, so we use a no-op Map.
write_transform = pcoll.producer.inputs[0].producer.transform

# Ensure that the properties passed through correctly
self.assertEqual('a_topic', pdone.producer.transform.dofn.topic_name)
self.assertEqual('a_topic', write_transform.dofn.topic_name)


@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
Expand Down
34 changes: 22 additions & 12 deletions sdks/python/apache_beam/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator
from apache_beam.pvalue import PCollection
from apache_beam.pvalue import PDone
from apache_beam.runners import PipelineRunner
from apache_beam.runners import create_runner
from apache_beam.transforms import ptransform
Expand Down Expand Up @@ -180,7 +181,6 @@ def _remove_labels_recursively(self, applied_transform):
def _replace(self, override):

assert isinstance(override, PTransformOverride)
matcher = override.get_matcher()

output_map = {}
output_replacements = {}
Expand All @@ -193,10 +193,12 @@ def __init__(self, pipeline):
self.pipeline = pipeline

def _replace_if_needed(self, original_transform_node):
if matcher(original_transform_node):
if override.matches(original_transform_node):
assert isinstance(original_transform_node, AppliedPTransform)
replacement_transform = override.get_replacement_transform(
original_transform_node.transform)
if replacement_transform is original_transform_node.transform:
return

replacement_transform_node = AppliedPTransform(
original_transform_node.parent, replacement_transform,
Expand Down Expand Up @@ -227,6 +229,10 @@ def _replace_if_needed(self, original_transform_node):
'have a single input. Tried to replace input of '
'AppliedPTransform %r that has %d inputs',
original_transform_node, len(inputs))
elif len(inputs) == 1:
input_node = inputs[0]
elif len(inputs) == 0:
input_node = pvalue.PBegin(self)

# We have to add the new AppliedTransform to the stack before expand()
# and pop it out later to make sure that parts get added correctly.
Expand All @@ -239,16 +245,18 @@ def _replace_if_needed(self, original_transform_node):
# with labels of the children of the original.
self.pipeline._remove_labels_recursively(original_transform_node)

new_output = replacement_transform.expand(inputs[0])
new_output = replacement_transform.expand(input_node)
replacement_transform_node.add_output(new_output)
if not new_output.producer:
new_output.producer = replacement_transform_node

# We only support replacing transforms with a single output with
# another transform that produces a single output.
# TODO: Support replacing PTransforms with multiple outputs.
if (len(original_transform_node.outputs) > 1 or
not isinstance(
original_transform_node.outputs[None], PCollection) or
not isinstance(new_output, PCollection)):
not isinstance(original_transform_node.outputs[None],
(PCollection, PDone)) or
not isinstance(new_output, (PCollection, PDone))):
raise NotImplementedError(
'PTransform overriding is only supported for PTransforms that '
'have a single output. Tried to replace output of '
Expand Down Expand Up @@ -314,11 +322,10 @@ def visit_transform(self, transform_node):
transform.inputs = input_replacements[transform]

def _check_replacement(self, override):
matcher = override.get_matcher()

class ReplacementValidator(PipelineVisitor):
def visit_transform(self, transform_node):
if matcher(transform_node):
if override.matches(transform_node):
raise RuntimeError('Transform node %r was not replaced as expected.',
transform_node)

Expand Down Expand Up @@ -852,12 +859,14 @@ class PTransformOverride(object):
__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def get_matcher(self):
"""Gives a matcher that will be used to to perform this override.
def matches(self, applied_ptransform):
"""Determines whether the given AppliedPTransform matches.

Args:
applied_ptransform: AppliedPTransform to be matched.

Returns:
a callable that takes an AppliedPTransform as a parameter and returns a
boolean as a result.
a bool indicating whether the given AppliedPTransform is a match.
"""
raise NotImplementedError

Expand All @@ -867,6 +876,7 @@ def get_replacement_transform(self, ptransform):

Args:
ptransform: PTransform to be replaced.

Returns:
A PTransform that will be the replacement for the PTransform given as an
argument.
Expand Down
9 changes: 3 additions & 6 deletions sdks/python/apache_beam/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,20 +310,17 @@ def raise_exception(exn):
'apache_beam.runners.direct.direct_runner._get_transform_overrides')
def test_ptransform_overrides(self, file_system_override_mock):

def my_par_do_matcher(applied_ptransform):
return isinstance(applied_ptransform.transform, DoubleParDo)

class MyParDoOverride(PTransformOverride):

def get_matcher(self):
return my_par_do_matcher
def matches(self, applied_ptransform):
return isinstance(applied_ptransform.transform, DoubleParDo)

def get_replacement_transform(self, ptransform):
if isinstance(ptransform, DoubleParDo):
return TripleParDo()
raise ValueError('Unsupported type of transform: %r', ptransform)

def get_overrides():
def get_overrides(unused_pipeline_options):
return [MyParDoOverride()]

file_system_override_mock.side_effect = get_overrides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,10 @@ def test_group_by_key_input_visitor_with_valid_inputs(self):
pcoll2.element_type = typehints.Any
pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
for pcoll in [pcoll1, pcoll2, pcoll3]:
applied = AppliedPTransform(None, transform, "label", [pcoll])
applied.outputs[None] = PCollection(None)
DataflowRunner.group_by_key_input_visitor().visit_transform(
AppliedPTransform(None, transform, "label", [pcoll]))
applied)
self.assertEqual(pcoll.element_type,
typehints.KV[typehints.Any, typehints.Any])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@
class CreatePTransformOverride(PTransformOverride):
"""A ``PTransformOverride`` for ``Create`` in streaming mode."""

def get_matcher(self):
return self.is_streaming_create

@staticmethod
def is_streaming_create(applied_ptransform):
def matches(self, applied_ptransform):
# Imported here to avoid circular dependencies.
# pylint: disable=wrong-import-order, wrong-import-position
from apache_beam import Create
Expand Down
Loading