From 16e5feff484d7022046fe59e621ab09f5c8799d9 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Mon, 29 Jan 2018 10:34:33 -0800 Subject: [PATCH 1/2] [BEAM-3566] Replace apply_* hooks in DirectRunner with PTransformOverrides --- sdks/python/apache_beam/io/gcp/pubsub_test.py | 56 +++- sdks/python/apache_beam/pipeline.py | 17 +- sdks/python/apache_beam/pipeline_test.py | 2 +- .../runners/dataflow/dataflow_runner_test.py | 4 +- .../runners/direct/direct_runner.py | 242 +++++++++++------- .../runners/direct/helper_transforms.py | 8 +- .../apache_beam/transforms/combiners.py | 39 +-- 7 files changed, 238 insertions(+), 130 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 0c4ba02db87a..36d40a1d0bbc 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -22,10 +22,13 @@ import hamcrest as hc +from apache_beam import Map from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub from apache_beam.io.gcp.pubsub import WriteStringsToPubSub from apache_beam.io.gcp.pubsub import _PubSubPayloadSink from apache_beam.io.gcp.pubsub import _PubSubPayloadSource +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.direct.direct_runner import _get_transform_overrides from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher @@ -43,25 +46,48 @@ class TestReadStringsFromPubSub(unittest.TestCase): def test_expand_with_topic(self): p = TestPipeline() - pcoll = p | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic', - None, 'a_label') - # Ensure that the output type is str + p.options.view_as(StandardOptions).streaming = True + pcoll = (p + | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic', + None, 'a_label') + | Map(lambda x: x)) + # Ensure that the output type is str. self.assertEqual(unicode, pcoll.element_type) + # Apply the necessary PTransformOverrides. + overrides = _get_transform_overrides(p.options) + p.replace_all(overrides) + + # Note that the direct output of ReadStringsFromPubSub will be replaced + # by a PTransformOverride, so we use a no-op Map. + read_transform = pcoll.producer.inputs[0].producer.transform + # Ensure that the properties passed through correctly - source = pcoll.producer.transform._source + source = read_transform._source self.assertEqual('a_topic', source.topic_name) self.assertEqual('a_label', source.id_label) def test_expand_with_subscription(self): p = TestPipeline() - pcoll = p | ReadStringsFromPubSub( - None, 'projects/fakeprj/subscriptions/a_subscription', 'a_label') + p.options.view_as(StandardOptions).streaming = True + pcoll = (p + | ReadStringsFromPubSub( + None, 'projects/fakeprj/subscriptions/a_subscription', + 'a_label') + | Map(lambda x: x)) # Ensure that the output type is str self.assertEqual(unicode, pcoll.element_type) + # Apply the necessary PTransformOverrides. + overrides = _get_transform_overrides(p.options) + p.replace_all(overrides) + + # Note that the direct output of ReadStringsFromPubSub will be replaced + # by a PTransformOverride, so we use a no-op Map. 
+ read_transform = pcoll.producer.inputs[0].producer.transform + # Ensure that the properties passed through correctly - source = pcoll.producer.transform._source + source = read_transform._source self.assertEqual('a_subscription', source.subscription_name) self.assertEqual('a_label', source.id_label) @@ -80,12 +106,22 @@ def test_expand_with_both_topic_and_subscription(self): class TestWriteStringsToPubSub(unittest.TestCase): def test_expand(self): p = TestPipeline() - pdone = (p + p.options.view_as(StandardOptions).streaming = True + pcoll = (p | ReadStringsFromPubSub('projects/fakeprj/topics/baz') - | WriteStringsToPubSub('projects/fakeprj/topics/a_topic')) + | WriteStringsToPubSub('projects/fakeprj/topics/a_topic') + | Map(lambda x: x)) + + # Apply the necessary PTransformOverrides. + overrides = _get_transform_overrides(p.options) + p.replace_all(overrides) + + # Note that the direct output of ReadStringsFromPubSub will be replaced + # by a PTransformOverride, so we use a no-op Map. + write_transform = pcoll.producer.inputs[0].producer.transform # Ensure that the properties passed through correctly - self.assertEqual('a_topic', pdone.producer.transform.dofn.topic_name) + self.assertEqual('a_topic', write_transform.dofn.topic_name) @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 4ac5ea86bf28..b34c08d808c3 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -62,6 +62,7 @@ from apache_beam.options.pipeline_options import TypeOptions from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator from apache_beam.pvalue import PCollection +from apache_beam.pvalue import PDone from apache_beam.runners import PipelineRunner from apache_beam.runners import create_runner from apache_beam.transforms import ptransform @@ -197,6 +198,8 @@ def _replace_if_needed(self, original_transform_node): assert isinstance(original_transform_node, AppliedPTransform) replacement_transform = override.get_replacement_transform( original_transform_node.transform) + if replacement_transform is original_transform_node.transform: + return replacement_transform_node = AppliedPTransform( original_transform_node.parent, replacement_transform, @@ -227,6 +230,10 @@ def _replace_if_needed(self, original_transform_node): 'have a single input. Tried to replace input of ' 'AppliedPTransform %r that has %d inputs', original_transform_node, len(inputs)) + elif len(inputs) == 1: + input_node = inputs[0] + elif len(inputs) == 0: + input_node = pvalue.PBegin(self) # We have to add the new AppliedTransform to the stack before expand() # and pop it out later to make sure that parts get added correctly. @@ -239,16 +246,18 @@ def _replace_if_needed(self, original_transform_node): # with labels of the children of the original. self.pipeline._remove_labels_recursively(original_transform_node) - new_output = replacement_transform.expand(inputs[0]) + new_output = replacement_transform.expand(input_node) replacement_transform_node.add_output(new_output) + if not new_output.producer: + new_output.producer = replacement_transform_node # We only support replacing transforms with a single output with # another transform that produces a single output. # TODO: Support replacing PTransforms with multiple outputs. 
if (len(original_transform_node.outputs) > 1 or - not isinstance( - original_transform_node.outputs[None], PCollection) or - not isinstance(new_output, PCollection)): + not isinstance(original_transform_node.outputs[None], + (PCollection, PDone)) or + not isinstance(new_output, (PCollection, PDone))): raise NotImplementedError( 'PTransform overriding is only supported for PTransforms that ' 'have a single output. Tried to replace output of ' diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 34ec48ecfd64..3381d1bdaa3d 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -323,7 +323,7 @@ def get_replacement_transform(self, ptransform): return TripleParDo() raise ValueError('Unsupported type of transform: %r', ptransform) - def get_overrides(): + def get_overrides(unused_pipeline_options): return [MyParDoOverride()] file_system_override_mock.side_effect = get_overrides diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py index 2d529e11d2c6..b5300a4a9f64 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py @@ -273,8 +273,10 @@ def test_group_by_key_input_visitor_with_valid_inputs(self): pcoll2.element_type = typehints.Any pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any] for pcoll in [pcoll1, pcoll2, pcoll3]: + applied = AppliedPTransform(None, transform, "label", [pcoll]) + applied.outputs[None] = PCollection(None) DataflowRunner.group_by_key_input_visitor().visit_transform( - AppliedPTransform(None, transform, "label", [pcoll])) + applied) self.assertEqual(pcoll.element_type, typehints.KV[typehints.Any, typehints.Any]) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 7f3200ea5f6d..33a390fd0b41 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -32,7 +32,6 @@ from apache_beam.metrics.execution import MetricsEnvironment from apache_beam.options.pipeline_options import DirectOptions from apache_beam.options.pipeline_options import StandardOptions -from apache_beam.options.pipeline_options import TypeOptions from apache_beam.options.value_provider import RuntimeValueProvider from apache_beam.pvalue import PCollection from apache_beam.runners.direct.bundle_factory import BundleFactory @@ -41,6 +40,7 @@ from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner from apache_beam.runners.runner import PipelineState +from apache_beam.transforms.core import CombinePerKey from apache_beam.transforms.core import _GroupAlsoByWindow from apache_beam.transforms.core import _GroupByKeyOnly from apache_beam.transforms.ptransform import PTransform @@ -87,7 +87,23 @@ def from_runner_api_parameter(payload, context): context.windowing_strategies.get_by_id(payload.value)) -def _get_transform_overrides(): +class _DirectReadStringsFromPubSub(PTransform): + def __init__(self, source): + self._source = source + + def _infer_output_coder(self, unused_input_type=None, + unused_input_coder=None): + return coders.StrUtf8Coder() + + def get_windowing(self, inputs): + return beam.Windowing(beam.window.GlobalWindows()) + + def expand(self, pvalue): + # This is handled as a native transform. 
+ return PCollection(self.pipeline) + + +def _get_transform_overrides(pipeline_options): # A list of PTransformOverride objects to be applied before running a pipeline # using DirectRunner. # Currently this only works for overrides where the input and output types do @@ -95,10 +111,135 @@ def _get_transform_overrides(): # For internal use only; no backwards-compatibility guarantees. # Importing following locally to avoid a circular dependency. + from apache_beam.pipeline import PTransformOverride from apache_beam.runners.sdf_common import SplittableParDoOverride + from apache_beam.runners.direct.helper_transforms import LiftedCombinePerKey from apache_beam.runners.direct.sdf_direct_runner import ProcessKeyedElementsViaKeyedWorkItemsOverride - return [SplittableParDoOverride(), - ProcessKeyedElementsViaKeyedWorkItemsOverride()] + + class CombinePerKeyOverride(PTransformOverride): + def get_matcher(self): + def _matcher(applied_ptransform): + if isinstance(applied_ptransform.transform, CombinePerKey): + return True + return _matcher + + def get_replacement_transform(self, transform): + # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems + # with resolving imports when they are at top. + # pylint: disable=wrong-import-position + try: + return LiftedCombinePerKey(transform.fn, transform.args, + transform.kwargs) + except NotImplementedError: + return transform + + class StreamingGroupByKeyOverride(PTransformOverride): + def get_matcher(self): + def _matcher(applied_ptransform): + # Note: we match the exact class, since we replace it with a subclass. + return applied_ptransform.transform.__class__ == _GroupByKeyOnly + return _matcher + + def get_replacement_transform(self, transform): + # Use specialized streaming implementation. + type_hints = transform.get_type_hints() + transform = (_StreamingGroupByKeyOnly() + .with_input_types(*type_hints.input_types[0]) + .with_output_types(*type_hints.output_types[0])) + return transform + + class StreamingGroupAlsoByWindowOverride(PTransformOverride): + def get_matcher(self): + def _matcher(applied_ptransform): + # Note: we match the exact class, since we replace it with a subclass. + return applied_ptransform.transform.__class__ == _GroupAlsoByWindow + return _matcher + + def get_replacement_transform(self, transform): + # Use specialized streaming implementation. + type_hints = transform.get_type_hints() + transform = (_StreamingGroupAlsoByWindow(transform.windowing) + .with_input_types(*type_hints.input_types[0]) + .with_output_types(*type_hints.output_types[0])) + return transform + + overrides = [SplittableParDoOverride(), + ProcessKeyedElementsViaKeyedWorkItemsOverride(), + CombinePerKeyOverride(),] + + # Add streaming overrides, if necessary. + if pipeline_options.view_as(StandardOptions).streaming: + overrides.append(StreamingGroupByKeyOverride()) + overrides.append(StreamingGroupAlsoByWindowOverride()) + + # Add PubSub overrides, if PubSub is available. 
+ pubsub = None + try: + from apache_beam.io.gcp import pubsub + except ImportError: + pass + if pubsub: + class ReadStringsFromPubSubOverride(PTransformOverride): + def get_matcher(self): + def _matcher(applied_ptransform): + return isinstance(applied_ptransform.transform, + pubsub.ReadStringsFromPubSub) + return _matcher + + def get_replacement_transform(self, transform): + if not pipeline_options.view_as(StandardOptions).streaming: + raise Exception('PubSub I/O is only available in streaming mode ' + '(use the --streaming flag).') + return _DirectReadStringsFromPubSub(transform._source) + + class WriteStringsToPubSubOverride(PTransformOverride): + def get_matcher(self): + def _matcher(applied_ptransform): + return isinstance(applied_ptransform.transform, + pubsub.WriteStringsToPubSub) + return _matcher + + def get_replacement_transform(self, transform): + if not pipeline_options.view_as(StandardOptions).streaming: + raise Exception('PubSub I/O is only available in streaming mode ' + '(use the --streaming flag).') + + class _DirectWriteToPubSub(beam.DoFn): + _topic = None + + def __init__(self, project, topic_name): + self.project = project + self.topic_name = topic_name + + def start_bundle(self): + if self._topic is None: + self._topic = pubsub.Client(project=self.project).topic( + self.topic_name) + self._buffer = [] + + def process(self, elem): + self._buffer.append(elem.encode('utf-8')) + if len(self._buffer) >= 100: + self._flush() + + def finish_bundle(self): + self._flush() + + def _flush(self): + if self._buffer: + with self._topic.batch() as batch: + for datum in self._buffer: + batch.publish(datum) + self._buffer = [] + + project = transform._sink.project + topic_name = transform._sink.topic_name + return beam.ParDo(_DirectWriteToPubSub(project, topic_name)) + + overrides.append(ReadStringsFromPubSubOverride()) + overrides.append(WriteStringsToPubSubOverride()) + + return overrides class DirectRunner(PipelineRunner): @@ -106,103 +247,12 @@ class DirectRunner(PipelineRunner): def __init__(self): self._use_test_clock = False # use RealClock() in production - self._ptransform_overrides = _get_transform_overrides() - - def apply_CombinePerKey(self, transform, pcoll): - if pcoll.pipeline._options.view_as(TypeOptions).runtime_type_check: - # TODO(robertwb): This can be reenabled once expansion happens after run. - return transform.expand(pcoll) - # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems - # with resolving imports when they are at top. - # pylint: disable=wrong-import-position - from apache_beam.runners.direct.helper_transforms import LiftedCombinePerKey - try: - return pcoll | LiftedCombinePerKey( - transform.fn, transform.args, transform.kwargs) - except NotImplementedError: - return transform.expand(pcoll) - - def apply_TestStream(self, transform, pcoll): - self._use_test_clock = True # use TestClock() for testing - return transform.expand(pcoll) - - def apply__GroupByKeyOnly(self, transform, pcoll): - if (transform.__class__ == _GroupByKeyOnly and - pcoll.pipeline._options.view_as(StandardOptions).streaming): - # Use specialized streaming implementation, if requested. 
-      type_hints = transform.get_type_hints()
-      return pcoll | (_StreamingGroupByKeyOnly()
-                      .with_input_types(*type_hints.input_types[0])
-                      .with_output_types(*type_hints.output_types[0]))
-    return transform.expand(pcoll)
-
-  def apply__GroupAlsoByWindow(self, transform, pcoll):
-    if (transform.__class__ == _GroupAlsoByWindow and
-        pcoll.pipeline._options.view_as(StandardOptions).streaming):
-      # Use specialized streaming implementation, if requested.
-      type_hints = transform.get_type_hints()
-      return pcoll | (_StreamingGroupAlsoByWindow(transform.windowing)
-                      .with_input_types(*type_hints.input_types[0])
-                      .with_output_types(*type_hints.output_types[0]))
-    return transform.expand(pcoll)
-
-  def apply_ReadStringsFromPubSub(self, transform, pcoll):
-    try:
-      from google.cloud import pubsub as unused_pubsub
-    except ImportError:
-      raise ImportError('Google Cloud PubSub not available, please install '
-                        'apache_beam[gcp]')
-    # Execute this as a native transform.
-    output = PCollection(pcoll.pipeline)
-    output.element_type = unicode
-    return output
-
-  def apply_WriteStringsToPubSub(self, transform, pcoll):
-    try:
-      from google.cloud import pubsub
-    except ImportError:
-      raise ImportError('Google Cloud PubSub not available, please install '
-                        'apache_beam[gcp]')
-    project = transform._sink.project
-    topic_name = transform._sink.topic_name
-
-    class DirectWriteToPubSub(beam.DoFn):
-      _topic = None
-
-      def __init__(self, project, topic_name):
-        self.project = project
-        self.topic_name = topic_name
-
-      def start_bundle(self):
-        if self._topic is None:
-          self._topic = pubsub.Client(project=self.project).topic(
-              self.topic_name)
-        self._buffer = []
-
-      def process(self, elem):
-        self._buffer.append(elem.encode('utf-8'))
-        if len(self._buffer) >= 100:
-          self._flush()
-
-      def finish_bundle(self):
-        self._flush()
-
-      def _flush(self):
-        if self._buffer:
-          with self._topic.batch() as batch:
-            for datum in self._buffer:
-              batch.publish(datum)
-          self._buffer = []
-
-    output = pcoll | beam.ParDo(DirectWriteToPubSub(project, topic_name))
-    output.element_type = unicode
-    return output
 
   def run_pipeline(self, pipeline):
     """Executes the entire pipeline and returns a DirectPipelineResult."""
 
     # Perform configured PTransform overrides.
-    pipeline.replace_all(self._ptransform_overrides)
+    pipeline.replace_all(_get_transform_overrides(pipeline.options))
 
     # TODO: Move imports to top. Pipeline <-> Runner dependencies cause
     # problems with import resolution when they are at the top.
diff --git a/sdks/python/apache_beam/runners/direct/helper_transforms.py b/sdks/python/apache_beam/runners/direct/helper_transforms.py
index 26b0701bd02b..0c1da0351264 100644
--- a/sdks/python/apache_beam/runners/direct/helper_transforms.py
+++ b/sdks/python/apache_beam/runners/direct/helper_transforms.py
@@ -21,6 +21,7 @@
 import apache_beam as beam
 from apache_beam import typehints
 from apache_beam.internal.util import ArgumentPlaceholder
+from apache_beam.transforms.combiners import _CurriedFn
 from apache_beam.utils.windowed_value import WindowedValue
 
 
@@ -28,8 +29,13 @@ class LiftedCombinePerKey(beam.PTransform):
   """An implementation of CombinePerKey that does mapper-side pre-combining.
""" def __init__(self, combine_fn, args, kwargs): + args_to_check = itertools.chain(args, kwargs.values()) + if isinstance(combine_fn, _CurriedFn): + args_to_check = itertools.chain(args_to_check, + combine_fn.args, + combine_fn.kwargs.values()) if any(isinstance(arg, ArgumentPlaceholder) - for arg in itertools.chain(args, kwargs.values())): + for arg in args_to_check): # This isn't implemented in dataflow either... raise NotImplementedError('Deferred CombineFn side inputs.') self._combine_fn = beam.transforms.combiners.curry_combine_fn( diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index 8e4188aca673..149048f7c9c6 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -536,30 +536,35 @@ def extract_output(self, accumulator): return accumulator -def curry_combine_fn(fn, args, kwargs): - if not args and not kwargs: - return fn +class _CurriedFn(core.CombineFn): + """Wrapped CombineFn with extra arguments.""" + + def __init__(self, fn, args, kwargs): + self.fn = fn + self.args = args + self.kwargs = kwargs + + def create_accumulator(self): + return self.fn.create_accumulator(*self.args, **self.kwargs) - # Create CurriedFn class for the combiner - class CurriedFn(core.CombineFn): - """CombineFn that applies extra arguments.""" + def add_input(self, accumulator, element): + return self.fn.add_input(accumulator, element, *self.args, **self.kwargs) - def create_accumulator(self): - return fn.create_accumulator(*args, **kwargs) + def merge_accumulators(self, accumulators): + return self.fn.merge_accumulators(accumulators, *self.args, **self.kwargs) - def add_input(self, accumulator, element): - return fn.add_input(accumulator, element, *args, **kwargs) + def extract_output(self, accumulator): + return self.fn.extract_output(accumulator, *self.args, **self.kwargs) - def merge_accumulators(self, accumulators): - return fn.merge_accumulators(accumulators, *args, **kwargs) + def apply(self, elements): + return self.fn.apply(elements, *self.args, **self.kwargs) - def extract_output(self, accumulator): - return fn.extract_output(accumulator, *args, **kwargs) - def apply(self, elements): - return fn.apply(elements, *args, **kwargs) +def curry_combine_fn(fn, args, kwargs): + if not args and not kwargs: + return fn - return CurriedFn() + return _CurriedFn(fn, args, kwargs) class PhasedCombineFnExecutor(object): From e1a96eb5aa532a92a4b2e9066dd5b5bcd7e8719b Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Fri, 2 Feb 2018 16:22:33 -0800 Subject: [PATCH 2/2] Address reviewer comments --- sdks/python/apache_beam/io/gcp/pubsub_test.py | 10 +- sdks/python/apache_beam/pipeline.py | 17 ++- sdks/python/apache_beam/pipeline_test.py | 7 +- .../runners/dataflow/ptransform_overrides.py | 6 +- .../runners/direct/direct_runner.py | 140 +++++++++--------- .../runners/direct/sdf_direct_runner.py | 9 +- sdks/python/apache_beam/runners/sdf_common.py | 15 +- .../apache_beam/transforms/combiners.py | 4 +- 8 files changed, 95 insertions(+), 113 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 36d40a1d0bbc..8bd9fa4f41aa 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -22,7 +22,7 @@ import hamcrest as hc -from apache_beam import Map +import apache_beam as beam from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub from apache_beam.io.gcp.pubsub import 
WriteStringsToPubSub from apache_beam.io.gcp.pubsub import _PubSubPayloadSink @@ -43,14 +43,14 @@ @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed') -class TestReadStringsFromPubSub(unittest.TestCase): +class TestReadStringsFromPubSubOverride(unittest.TestCase): def test_expand_with_topic(self): p = TestPipeline() p.options.view_as(StandardOptions).streaming = True pcoll = (p | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic', None, 'a_label') - | Map(lambda x: x)) + | beam.Map(lambda x: x)) # Ensure that the output type is str. self.assertEqual(unicode, pcoll.element_type) @@ -74,7 +74,7 @@ def test_expand_with_subscription(self): | ReadStringsFromPubSub( None, 'projects/fakeprj/subscriptions/a_subscription', 'a_label') - | Map(lambda x: x)) + | beam.Map(lambda x: x)) # Ensure that the output type is str self.assertEqual(unicode, pcoll.element_type) @@ -110,7 +110,7 @@ def test_expand(self): pcoll = (p | ReadStringsFromPubSub('projects/fakeprj/topics/baz') | WriteStringsToPubSub('projects/fakeprj/topics/a_topic') - | Map(lambda x: x)) + | beam.Map(lambda x: x)) # Apply the necessary PTransformOverrides. overrides = _get_transform_overrides(p.options) diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index b34c08d808c3..c59a29afa7e5 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -181,7 +181,6 @@ def _remove_labels_recursively(self, applied_transform): def _replace(self, override): assert isinstance(override, PTransformOverride) - matcher = override.get_matcher() output_map = {} output_replacements = {} @@ -194,7 +193,7 @@ def __init__(self, pipeline): self.pipeline = pipeline def _replace_if_needed(self, original_transform_node): - if matcher(original_transform_node): + if override.matches(original_transform_node): assert isinstance(original_transform_node, AppliedPTransform) replacement_transform = override.get_replacement_transform( original_transform_node.transform) @@ -323,11 +322,10 @@ def visit_transform(self, transform_node): transform.inputs = input_replacements[transform] def _check_replacement(self, override): - matcher = override.get_matcher() class ReplacementValidator(PipelineVisitor): def visit_transform(self, transform_node): - if matcher(transform_node): + if override.matches(transform_node): raise RuntimeError('Transform node %r was not replaced as expected.', transform_node) @@ -861,12 +859,14 @@ class PTransformOverride(object): __metaclass__ = abc.ABCMeta @abc.abstractmethod - def get_matcher(self): - """Gives a matcher that will be used to to perform this override. + def matches(self, applied_ptransform): + """Determines whether the given AppliedPTransform matches. + + Args: + applied_ptransform: AppliedPTransform to be matched. Returns: - a callable that takes an AppliedPTransform as a parameter and returns a - boolean as a result. + a bool indicating whether the given AppliedPTransform is a match. """ raise NotImplementedError @@ -876,6 +876,7 @@ def get_replacement_transform(self, ptransform): Args: ptransform: PTransform to be replaced. + Returns: A PTransform that will be the replacement for the PTransform given as an argument. 
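For orientation, a minimal sketch (not part of the patch) of an override written against the revised API above: matches() is now defined directly on the PTransformOverride subclass instead of being returned as a closure from get_matcher(). The DoubleParDo/TripleParDo toy transforms are assumed here for illustration; they mirror the ones used in pipeline_test.py below.

import apache_beam as beam
from apache_beam.pipeline import PTransformOverride

class DoubleParDo(beam.PTransform):
  def expand(self, pcoll):
    return pcoll | 'Double' >> beam.Map(lambda x: x * 2)

class TripleParDo(beam.PTransform):
  def expand(self, pcoll):
    return pcoll | 'Triple' >> beam.Map(lambda x: x * 3)

class MyParDoOverride(PTransformOverride):
  def matches(self, applied_ptransform):
    # Called on each AppliedPTransform node during replace_all().
    return isinstance(applied_ptransform.transform, DoubleParDo)

  def get_replacement_transform(self, ptransform):
    # Returning the original transform unchanged is also allowed;
    # _replace_if_needed() now short-circuits when the replacement
    # is the same instance.
    return TripleParDo()

p = beam.Pipeline()
pcoll = p | beam.Create([1, 2, 3]) | 'Multiply' >> DoubleParDo()
p.replace_all([MyParDoOverride()])  # 'Multiply' now expands via TripleParDo.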
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 3381d1bdaa3d..3b26d3fc1bca 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -310,13 +310,10 @@ def raise_exception(exn): 'apache_beam.runners.direct.direct_runner._get_transform_overrides') def test_ptransform_overrides(self, file_system_override_mock): - def my_par_do_matcher(applied_ptransform): - return isinstance(applied_ptransform.transform, DoubleParDo) - class MyParDoOverride(PTransformOverride): - def get_matcher(self): - return my_par_do_matcher + def matches(self, applied_ptransform): + return isinstance(applied_ptransform.transform, DoubleParDo) def get_replacement_transform(self, ptransform): if isinstance(ptransform, DoubleParDo): diff --git a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py index 680a4b7de5c2..0ce212fa31bd 100644 --- a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py +++ b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py @@ -24,11 +24,7 @@ class CreatePTransformOverride(PTransformOverride): """A ``PTransformOverride`` for ``Create`` in streaming mode.""" - def get_matcher(self): - return self.is_streaming_create - - @staticmethod - def is_streaming_create(applied_ptransform): + def matches(self, applied_ptransform): # Imported here to avoid circular dependencies. # pylint: disable=wrong-import-order, wrong-import-position from apache_beam import Create diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 33a390fd0b41..d82fa15e111b 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -117,11 +117,9 @@ def _get_transform_overrides(pipeline_options): from apache_beam.runners.direct.sdf_direct_runner import ProcessKeyedElementsViaKeyedWorkItemsOverride class CombinePerKeyOverride(PTransformOverride): - def get_matcher(self): - def _matcher(applied_ptransform): - if isinstance(applied_ptransform.transform, CombinePerKey): - return True - return _matcher + def matches(self, applied_ptransform): + if isinstance(applied_ptransform.transform, CombinePerKey): + return True def get_replacement_transform(self, transform): # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems @@ -134,11 +132,9 @@ def get_replacement_transform(self, transform): return transform class StreamingGroupByKeyOverride(PTransformOverride): - def get_matcher(self): - def _matcher(applied_ptransform): - # Note: we match the exact class, since we replace it with a subclass. - return applied_ptransform.transform.__class__ == _GroupByKeyOnly - return _matcher + def matches(self, applied_ptransform): + # Note: we match the exact class, since we replace it with a subclass. + return applied_ptransform.transform.__class__ == _GroupByKeyOnly def get_replacement_transform(self, transform): # Use specialized streaming implementation. @@ -149,11 +145,9 @@ def get_replacement_transform(self, transform): return transform class StreamingGroupAlsoByWindowOverride(PTransformOverride): - def get_matcher(self): - def _matcher(applied_ptransform): - # Note: we match the exact class, since we replace it with a subclass. 
- return applied_ptransform.transform.__class__ == _GroupAlsoByWindow - return _matcher + def matches(self, applied_ptransform): + # Note: we match the exact class, since we replace it with a subclass. + return applied_ptransform.transform.__class__ == _GroupAlsoByWindow def get_replacement_transform(self, transform): # Use specialized streaming implementation. @@ -165,7 +159,7 @@ def get_replacement_transform(self, transform): overrides = [SplittableParDoOverride(), ProcessKeyedElementsViaKeyedWorkItemsOverride(), - CombinePerKeyOverride(),] + CombinePerKeyOverride()] # Add streaming overrides, if necessary. if pipeline_options.view_as(StandardOptions).streaming: @@ -173,73 +167,73 @@ def get_replacement_transform(self, transform): overrides.append(StreamingGroupAlsoByWindowOverride()) # Add PubSub overrides, if PubSub is available. - pubsub = None try: - from apache_beam.io.gcp import pubsub + from apache_beam.io.gcp import pubsub as unused_pubsub + overrides += _get_pubsub_transform_overrides(pipeline_options) except ImportError: pass - if pubsub: - class ReadStringsFromPubSubOverride(PTransformOverride): - def get_matcher(self): - def _matcher(applied_ptransform): - return isinstance(applied_ptransform.transform, - pubsub.ReadStringsFromPubSub) - return _matcher - - def get_replacement_transform(self, transform): - if not pipeline_options.view_as(StandardOptions).streaming: - raise Exception('PubSub I/O is only available in streaming mode ' - '(use the --streaming flag).') - return _DirectReadStringsFromPubSub(transform._source) - - class WriteStringsToPubSubOverride(PTransformOverride): - def get_matcher(self): - def _matcher(applied_ptransform): - return isinstance(applied_ptransform.transform, - pubsub.WriteStringsToPubSub) - return _matcher - - def get_replacement_transform(self, transform): - if not pipeline_options.view_as(StandardOptions).streaming: - raise Exception('PubSub I/O is only available in streaming mode ' - '(use the --streaming flag).') - - class _DirectWriteToPubSub(beam.DoFn): - _topic = None - - def __init__(self, project, topic_name): - self.project = project - self.topic_name = topic_name - - def start_bundle(self): - if self._topic is None: - self._topic = pubsub.Client(project=self.project).topic( - self.topic_name) - self._buffer = [] - def process(self, elem): - self._buffer.append(elem.encode('utf-8')) - if len(self._buffer) >= 100: - self._flush() + return overrides + + +def _get_pubsub_transform_overrides(pipeline_options): + from apache_beam.io.gcp import pubsub + from apache_beam.pipeline import PTransformOverride + + class ReadStringsFromPubSubOverride(PTransformOverride): + def matches(self, applied_ptransform): + return isinstance(applied_ptransform.transform, + pubsub.ReadStringsFromPubSub) - def finish_bundle(self): + def get_replacement_transform(self, transform): + if not pipeline_options.view_as(StandardOptions).streaming: + raise Exception('PubSub I/O is only available in streaming mode ' + '(use the --streaming flag).') + return _DirectReadStringsFromPubSub(transform._source) + + class WriteStringsToPubSubOverride(PTransformOverride): + def matches(self, applied_ptransform): + return isinstance(applied_ptransform.transform, + pubsub.WriteStringsToPubSub) + + def get_replacement_transform(self, transform): + if not pipeline_options.view_as(StandardOptions).streaming: + raise Exception('PubSub I/O is only available in streaming mode ' + '(use the --streaming flag).') + + class _DirectWriteToPubSub(beam.DoFn): + _topic = None + + def 
__init__(self, project, topic_name): + self.project = project + self.topic_name = topic_name + + def start_bundle(self): + if self._topic is None: + self._topic = pubsub.Client(project=self.project).topic( + self.topic_name) + self._buffer = [] + + def process(self, elem): + self._buffer.append(elem.encode('utf-8')) + if len(self._buffer) >= 100: self._flush() - def _flush(self): - if self._buffer: - with self._topic.batch() as batch: - for datum in self._buffer: - batch.publish(datum) - self._buffer = [] + def finish_bundle(self): + self._flush() - project = transform._sink.project - topic_name = transform._sink.topic_name - return beam.ParDo(_DirectWriteToPubSub(project, topic_name)) + def _flush(self): + if self._buffer: + with self._topic.batch() as batch: + for datum in self._buffer: + batch.publish(datum) + self._buffer = [] - overrides.append(ReadStringsFromPubSubOverride()) - overrides.append(WriteStringsToPubSubOverride()) + project = transform._sink.project + topic_name = transform._sink.topic_name + return beam.ParDo(_DirectWriteToPubSub(project, topic_name)) - return overrides + return [ReadStringsFromPubSubOverride(), WriteStringsToPubSubOverride()] class DirectRunner(PipelineRunner): diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py index ddbe9649b424..aa247aa4118b 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py @@ -44,12 +44,9 @@ class ProcessKeyedElementsViaKeyedWorkItemsOverride(PTransformOverride): """A transform override for ProcessElements transform.""" - def get_matcher(self): - def _matcher(applied_ptransform): - return isinstance( - applied_ptransform.transform, ProcessKeyedElements) - - return _matcher + def matches(self, applied_ptransform): + return isinstance( + applied_ptransform.transform, ProcessKeyedElements) def get_replacement_transform(self, ptransform): return ProcessKeyedElementsViaKeyedWorkItems(ptransform) diff --git a/sdks/python/apache_beam/runners/sdf_common.py b/sdks/python/apache_beam/runners/sdf_common.py index a7d80ac8b180..a3e141891236 100644 --- a/sdks/python/apache_beam/runners/sdf_common.py +++ b/sdks/python/apache_beam/runners/sdf_common.py @@ -37,15 +37,12 @@ class SplittableParDoOverride(PTransformOverride): SDF specific logic. """ - def get_matcher(self): - def _matcher(applied_ptransform): - assert isinstance(applied_ptransform, AppliedPTransform) - transform = applied_ptransform.transform - if isinstance(transform, ParDo): - signature = DoFnSignature(transform.fn) - return signature.is_splittable_dofn() - - return _matcher + def matches(self, applied_ptransform): + assert isinstance(applied_ptransform, AppliedPTransform) + transform = applied_ptransform.transform + if isinstance(transform, ParDo): + signature = DoFnSignature(transform.fn) + return signature.is_splittable_dofn() def get_replacement_transform(self, ptransform): assert isinstance(ptransform, ParDo) diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index 149048f7c9c6..e29855e5f8fc 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -563,8 +563,8 @@ def apply(self, elements): def curry_combine_fn(fn, args, kwargs): if not args and not kwargs: return fn - - return _CurriedFn(fn, args, kwargs) + else: + return _CurriedFn(fn, args, kwargs) class PhasedCombineFnExecutor(object):
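As a closing illustration: the combiners.py refactoring above replaces the old closure-based CurriedFn with the named, introspectable _CurriedFn, which is what lets LiftedCombinePerKey chain a curried fn's bound args and kwargs into its ArgumentPlaceholder check for deferred side inputs. A rough sketch of the resulting behavior follows; LargestFn is a hypothetical combiner invented for this example (modeled on Top-style CombineFns that take extra arguments), not part of the patch.

import apache_beam as beam
from apache_beam.transforms.combiners import curry_combine_fn

class LargestFn(beam.CombineFn):
  # A CombineFn whose methods all take one extra argument, n.
  def create_accumulator(self, n):
    return []

  def add_input(self, accumulator, element, n):
    accumulator.append(element)
    return sorted(accumulator, reverse=True)[:n]

  def merge_accumulators(self, accumulators, n):
    merged = [x for acc in accumulators for x in acc]
    return sorted(merged, reverse=True)[:n]

  def extract_output(self, accumulator, n):
    return accumulator

curried = curry_combine_fn(LargestFn(), (3,), {})
# Unlike the old closure class, _CurriedFn keeps the bound arguments visible:
assert curried.args == (3,)
# Each delegated call re-applies the curried arguments (n=3 here):
acc = curried.add_input(curried.create_accumulator(), 5)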