diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java index c790463dcf27..3300723ad298 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java @@ -41,7 +41,7 @@ * {@link Pipeline}. This is used to schedule consuming {@link PTransform PTransforms} to consume * input after the upstream transform has produced and committed output. */ -public class ConsumerTrackingPipelineVisitor implements PipelineVisitor { +public class ConsumerTrackingPipelineVisitor extends PipelineVisitor.Defaults { private Map>> valueToConsumers = new HashMap<>(); private Collection> rootTransforms = new ArrayList<>(); private Collection> views = new ArrayList<>(); @@ -51,13 +51,14 @@ public class ConsumerTrackingPipelineVisitor implements PipelineVisitor { private boolean finalized = false; @Override - public void enterCompositeTransform(TransformTreeNode node) { + public CompositeBehavior enterCompositeTransform(TransformTreeNode node) { checkState( !finalized, "Attempting to traverse a pipeline (node %s) with a %s " + "which has already visited a Pipeline and is finalized", node.getFullName(), ConsumerTrackingPipelineVisitor.class.getSimpleName()); + return CompositeBehavior.ENTER_TRANSFORM; } @Override @@ -73,7 +74,7 @@ public void leaveCompositeTransform(TransformTreeNode node) { } @Override - public void visitTransform(TransformTreeNode node) { + public void visitPrimitiveTransform(TransformTreeNode node) { toFinalize.removeAll(node.getInput().expand()); AppliedPTransform appliedTransform = getAppliedTransform(node); stepNames.put(appliedTransform, genStepName()); diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java index b7c755ec1c8a..2fea00a6e5d4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java @@ -56,12 +56,13 @@ private KeyedPValueTrackingVisitor( } @Override - public void enterCompositeTransform(TransformTreeNode node) { + public CompositeBehavior enterCompositeTransform(TransformTreeNode node) { checkState( !finalized, "Attempted to use a %s that has already been finalized on a pipeline (visiting node %s)", KeyedPValueTrackingVisitor.class.getSimpleName(), node); + return CompositeBehavior.ENTER_TRANSFORM; } @Override @@ -79,7 +80,7 @@ public void leaveCompositeTransform(TransformTreeNode node) { } @Override - public void visitTransform(TransformTreeNode node) {} + public void visitPrimitiveTransform(TransformTreeNode node) {} @Override public void visitValue(PValue value, TransformTreeNode producer) { diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java index 456cf09457b6..3d39e8182cab 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java @@ -43,11 +43,6 @@ public class FlinkBatchPipelineTranslator extends FlinkPipelineTranslator { private int depth = 0; - /** - * Composite transform that we want to translate before proceeding with other transforms. 
- */ - private PTransform currentCompositeTransform; - public FlinkBatchPipelineTranslator(ExecutionEnvironment env, PipelineOptions options) { this.batchContext = new FlinkBatchTranslationContext(env, options); } @@ -57,54 +52,33 @@ public FlinkBatchPipelineTranslator(ExecutionEnvironment env, PipelineOptions op // -------------------------------------------------------------------------------------------- @Override - public void enterCompositeTransform(TransformTreeNode node) { + public CompositeBehavior enterCompositeTransform(TransformTreeNode node) { LOG.info(genSpaces(this.depth) + "enterCompositeTransform- " + formatNodeName(node)); - PTransform transform = node.getTransform(); - if (transform != null && currentCompositeTransform == null) { - - BatchTransformTranslator translator = FlinkBatchTransformTranslators.getTranslator(transform); - if (translator != null) { - currentCompositeTransform = transform; - if (transform instanceof CoGroupByKey && node.getInput().expand().size() != 2) { - // we can only optimize CoGroupByKey for input size 2 - currentCompositeTransform = null; - } - } + BatchTransformTranslator translator = getTranslator(node); + + if (translator != null) { + applyBatchTransform(node.getTransform(), node, translator); + LOG.info(genSpaces(this.depth) + "translated-" + formatNodeName(node)); + return CompositeBehavior.DO_NOT_ENTER_TRANSFORM; + } else { + this.depth++; + return CompositeBehavior.ENTER_TRANSFORM; } - this.depth++; } @Override public void leaveCompositeTransform(TransformTreeNode node) { - PTransform transform = node.getTransform(); - if (transform != null && currentCompositeTransform == transform) { - - BatchTransformTranslator translator = FlinkBatchTransformTranslators.getTranslator(transform); - if (translator != null) { - LOG.info(genSpaces(this.depth) + "doingCompositeTransform- " + formatNodeName(node)); - applyBatchTransform(transform, node, translator); - currentCompositeTransform = null; - } else { - throw new 
IllegalStateException("Attempted to translate composite transform " + - "but no translator was found: " + currentCompositeTransform); - } - } this.depth--; LOG.info(genSpaces(this.depth) + "leaveCompositeTransform- " + formatNodeName(node)); } @Override - public void visitTransform(TransformTreeNode node) { - LOG.info(genSpaces(this.depth) + "visitTransform- " + formatNodeName(node)); - if (currentCompositeTransform != null) { - // ignore it - return; - } + public void visitPrimitiveTransform(TransformTreeNode node) { + LOG.info(genSpaces(this.depth) + "visitPrimitiveTransform- " + formatNodeName(node)); - // get the transformation corresponding to hte node we are + // get the transformation corresponding to the node we are // currently visiting and translate it into its Flink alternative. - PTransform transform = node.getTransform(); BatchTransformTranslator translator = FlinkBatchTransformTranslators.getTranslator(transform); if (translator == null) { @@ -114,11 +88,6 @@ public void visitTransform(TransformTreeNode node) { applyBatchTransform(transform, node, translator); } - @Override - public void visitValue(PValue value, TransformTreeNode producer) { - // do nothing here - } - private > void applyBatchTransform(PTransform transform, TransformTreeNode node, BatchTransformTranslator translator) { @SuppressWarnings("unchecked") @@ -140,6 +109,32 @@ public interface BatchTransformTranslator { void translateNode(Type transform, FlinkBatchTranslationContext context); } + /** + * Returns a translator for the given node, if it is possible, otherwise null. 
+ */ + private static BatchTransformTranslator getTranslator(TransformTreeNode node) { + PTransform transform = node.getTransform(); + + // Root of the graph is null + if (transform == null) { + return null; + } + + BatchTransformTranslator translator = FlinkBatchTransformTranslators.getTranslator(transform); + + // No translator known + if (translator == null) { + return null; + } + + // We actually only specialize CoGroupByKey when exactly 2 inputs + if (transform instanceof CoGroupByKey && node.getInput().expand().size() != 2) { + return null; + } + + return translator; + } + private static String genSpaces(int n) { String s = ""; for (int i = 0; i < n; i++) { diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java index 82d23b0fab0a..46e571205fb3 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java @@ -28,7 +28,7 @@ * a {@link org.apache.flink.streaming.api.datastream.DataStream} (for streaming) or a * {@link org.apache.flink.api.java.DataSet} (for batch) one. 
*/ -public abstract class FlinkPipelineTranslator implements Pipeline.PipelineVisitor { +public abstract class FlinkPipelineTranslator extends Pipeline.PipelineVisitor.Defaults { public void translate(Pipeline pipeline) { pipeline.traverseTopologically(this); diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java index ebaf6ba0234b..31b2bee63f72 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java @@ -43,9 +43,6 @@ public class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator { private int depth = 0; - /** Composite transform that we want to translate before proceeding with other transforms. */ - private PTransform currentCompositeTransform; - public FlinkStreamingPipelineTranslator(StreamExecutionEnvironment env, PipelineOptions options) { this.streamingContext = new FlinkStreamingTranslationContext(env, options); } @@ -55,47 +52,31 @@ public FlinkStreamingPipelineTranslator(StreamExecutionEnvironment env, Pipeline // -------------------------------------------------------------------------------------------- @Override - public void enterCompositeTransform(TransformTreeNode node) { + public CompositeBehavior enterCompositeTransform(TransformTreeNode node) { LOG.info(genSpaces(this.depth) + "enterCompositeTransform- " + formatNodeName(node)); PTransform transform = node.getTransform(); - if (transform != null && currentCompositeTransform == null) { - + if (transform != null) { StreamTransformTranslator translator = FlinkStreamingTransformTranslators.getTranslator(transform); if (translator != null) { - currentCompositeTransform = transform; + 
applyStreamingTransform(transform, node, translator); + LOG.info(genSpaces(this.depth) + "translated-" + formatNodeName(node)); + return CompositeBehavior.DO_NOT_ENTER_TRANSFORM; } } this.depth++; + return CompositeBehavior.ENTER_TRANSFORM; } @Override public void leaveCompositeTransform(TransformTreeNode node) { - PTransform transform = node.getTransform(); - if (transform != null && currentCompositeTransform == transform) { - - StreamTransformTranslator translator = FlinkStreamingTransformTranslators.getTranslator(transform); - if (translator != null) { - LOG.info(genSpaces(this.depth) + "doingCompositeTransform- " + formatNodeName(node)); - applyStreamingTransform(transform, node, translator); - currentCompositeTransform = null; - } else { - throw new IllegalStateException("Attempted to translate composite transform " + - "but no translator was found: " + currentCompositeTransform); - } - } this.depth--; LOG.info(genSpaces(this.depth) + "leaveCompositeTransform- " + formatNodeName(node)); } @Override - public void visitTransform(TransformTreeNode node) { - LOG.info(genSpaces(this.depth) + "visitTransform- " + formatNodeName(node)); - if (currentCompositeTransform != null) { - // ignore it - return; - } - + public void visitPrimitiveTransform(TransformTreeNode node) { + LOG.info(genSpaces(this.depth) + "visitPrimitiveTransform- " + formatNodeName(node)); // get the transformation corresponding to hte node we are // currently visiting and translate it into its Flink alternative. 
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java index 41b4df7717b2..407680215183 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java @@ -680,17 +680,18 @@ public void visitValue(PValue value, TransformTreeNode producer) { } @Override - public void visitTransform(TransformTreeNode node) { + public void visitPrimitiveTransform(TransformTreeNode node) { if (ptransformViewsWithNonDeterministicKeyCoders.contains(node.getTransform())) { ptransformViewNamesWithNonDeterministicKeyCoders.add(node.getFullName()); } } @Override - public void enterCompositeTransform(TransformTreeNode node) { + public CompositeBehavior enterCompositeTransform(TransformTreeNode node) { if (ptransformViewsWithNonDeterministicKeyCoders.contains(node.getTransform())) { ptransformViewNamesWithNonDeterministicKeyCoders.add(node.getFullName()); } + return CompositeBehavior.ENTER_TRANSFORM; } @Override diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 4ef1bdb9878d..05879d9dd8f0 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -348,7 +348,7 @@ public void addCollectionToSingletonOutput(String name, /** * Translates a Pipeline into the Dataflow representation. 
*/ - class Translator implements PipelineVisitor, TranslationContext { + class Translator extends PipelineVisitor.Defaults implements TranslationContext { /** The Pipeline to translate. */ private final Pipeline pipeline; @@ -493,16 +493,13 @@ public String getFullName(PTransform transform) { return currentTransform; } - @Override - public void enterCompositeTransform(TransformTreeNode node) { - } @Override public void leaveCompositeTransform(TransformTreeNode node) { } @Override - public void visitTransform(TransformTreeNode node) { + public void visitPrimitiveTransform(TransformTreeNode node) { PTransform transform = node.getTransform(); TransformTranslator translator = getTransformTranslator(transform.getClass()); diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineRunnerTest.java index d4d4b3b73186..2993c5012027 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineRunnerTest.java @@ -84,7 +84,6 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -820,26 +819,15 @@ public void translate( } /** Records all the composite transforms visited within the Pipeline. 
*/ - private static class CompositeTransformRecorder implements PipelineVisitor { + private static class CompositeTransformRecorder extends PipelineVisitor.Defaults { private List> transforms = new ArrayList<>(); @Override - public void enterCompositeTransform(TransformTreeNode node) { + public CompositeBehavior enterCompositeTransform(TransformTreeNode node) { if (node.getTransform() != null) { transforms.add(node.getTransform()); } - } - - @Override - public void leaveCompositeTransform(TransformTreeNode node) { - } - - @Override - public void visitTransform(TransformTreeNode node) { - } - - @Override - public void visitValue(PValue value, TransformTreeNode producer) { + return CompositeBehavior.ENTER_TRANSFORM; } public List> getCompositeTransforms() { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java index bae4e53874ec..af5acf1afc19 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java @@ -41,7 +41,6 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.PValue; import org.apache.spark.SparkException; import org.apache.spark.api.java.JavaSparkContext; @@ -219,7 +218,7 @@ public EvaluationResult run(Pipeline pipeline) { /** * Evaluator on the pipeline. 
*/ - public abstract static class Evaluator implements Pipeline.PipelineVisitor { + public abstract static class Evaluator extends Pipeline.PipelineVisitor.Defaults { protected static final Logger LOG = LoggerFactory.getLogger(Evaluator.class); protected final SparkPipelineTranslator translator; @@ -228,62 +227,29 @@ protected Evaluator(SparkPipelineTranslator translator) { this.translator = translator; } - // Set upon entering a composite node which can be directly mapped to a single - // TransformEvaluator. - private TransformTreeNode currentTranslatedCompositeNode; - - /** - * If true, we're currently inside a subtree of a composite node which directly maps to a - * single - * TransformEvaluator; children nodes are ignored, and upon post-visiting the translated - * composite node, the associated TransformEvaluator will be visited. - */ - private boolean inTranslatedCompositeNode() { - return currentTranslatedCompositeNode != null; - } - @Override - public void enterCompositeTransform(TransformTreeNode node) { - if (!inTranslatedCompositeNode() && node.getTransform() != null) { + public CompositeBehavior enterCompositeTransform(TransformTreeNode node) { + if (node.getTransform() != null) { @SuppressWarnings("unchecked") Class> transformClass = (Class>) node.getTransform().getClass(); if (translator.hasTranslation(transformClass)) { LOG.info("Entering directly-translatable composite transform: '{}'", node.getFullName()); LOG.debug("Composite transform class: '{}'", transformClass); - currentTranslatedCompositeNode = node; + doVisitTransform(node); + return CompositeBehavior.DO_NOT_ENTER_TRANSFORM; } } + return CompositeBehavior.ENTER_TRANSFORM; } @Override - public void leaveCompositeTransform(TransformTreeNode node) { - // NB: We depend on enterCompositeTransform and leaveCompositeTransform providing 'node' - // objects for which Object.equals() returns true iff they are the same logical node - // within the tree. 
- if (inTranslatedCompositeNode() && node.equals(currentTranslatedCompositeNode)) { - LOG.info("Post-visiting directly-translatable composite transform: '{}'", - node.getFullName()); - doVisitTransform(node); - currentTranslatedCompositeNode = null; - } - } - - @Override - public void visitTransform(TransformTreeNode node) { - if (inTranslatedCompositeNode()) { - LOG.info("Skipping '{}'; already in composite transform.", node.getFullName()); - return; - } + public void visitPrimitiveTransform(TransformTreeNode node) { doVisitTransform(node); } protected abstract > void doVisitTransform(TransformTreeNode node); - - @Override - public void visitValue(PValue value, TransformTreeNode producer) { - } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java index 65a0755c41f2..4e7e63f0304d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java @@ -220,8 +220,10 @@ public interface PipelineVisitor { /** * Called for each composite transform after all topological predecessors have been visited * but before any of its component transforms. + * + *
<p>
The return value controls whether or not child transforms are visited. */ - public void enterCompositeTransform(TransformTreeNode node); + public CompositeBehavior enterCompositeTransform(TransformTreeNode node); /** * Called for each composite transform after all of its component transforms and their outputs @@ -233,13 +235,42 @@ public interface PipelineVisitor { * Called for each primitive transform after all of its topological predecessors * and inputs have been visited. */ - public void visitTransform(TransformTreeNode node); + public void visitPrimitiveTransform(TransformTreeNode node); /** * Called for each value after the transform that produced the value has been * visited. */ public void visitValue(PValue value, TransformTreeNode producer); + + /** + * Control enum for indicating whether or not a traversal should process the contents of + * a composite transform or not. + */ + public enum CompositeBehavior { + ENTER_TRANSFORM, + DO_NOT_ENTER_TRANSFORM; + } + + /** + * Default no-op {@link PipelineVisitor} that enters all composite transforms. + * User implementations can override just those methods they are interested in. 
+ */ + public class Defaults implements PipelineVisitor { + @Override + public CompositeBehavior enterCompositeTransform(TransformTreeNode node) { + return CompositeBehavior.ENTER_TRANSFORM; + } + + @Override + public void leaveCompositeTransform(TransformTreeNode node) { } + + @Override + public void visitPrimitiveTransform(TransformTreeNode node) { } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { } + } } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java index 86a851faade3..146ddfa42313 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java @@ -56,7 +56,7 @@ public AggregatorPipelineExtractor(Pipeline pipeline) { return aggregatorSteps.asMap(); } - private static class AggregatorVisitor implements PipelineVisitor { + private static class AggregatorVisitor extends PipelineVisitor.Defaults { private final SetMultimap, PTransform> aggregatorSteps; public AggregatorVisitor(SetMultimap, PTransform> aggregatorSteps) { @@ -64,13 +64,7 @@ public AggregatorVisitor(SetMultimap, PTransform> aggrega } @Override - public void enterCompositeTransform(TransformTreeNode node) {} - - @Override - public void leaveCompositeTransform(TransformTreeNode node) {} - - @Override - public void visitTransform(TransformTreeNode node) { + public void visitPrimitiveTransform(TransformTreeNode node) { PTransform transform = node.getTransform(); addStepToAggregators(transform, getAggregators(transform)); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/DirectPipelineRunner.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/DirectPipelineRunner.java index 3cb970300c1d..590ce6fa8139 100644 --- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/DirectPipelineRunner.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/DirectPipelineRunner.java @@ -828,7 +828,7 @@ T ensureSerializableByCoder(Coder coder, ///////////////////////////////////////////////////////////////////////////// - class Evaluator implements PipelineVisitor, EvaluationContext { + class Evaluator extends PipelineVisitor.Defaults implements EvaluationContext { /** * A map from PTransform to the step name of that transform. This is the internal name for the * transform (e.g. "s2"). @@ -881,15 +881,7 @@ public OutputT getOutput(PTransform transf } @Override - public void enterCompositeTransform(TransformTreeNode node) { - } - - @Override - public void leaveCompositeTransform(TransformTreeNode node) { - } - - @Override - public void visitTransform(TransformTreeNode node) { + public void visitPrimitiveTransform(TransformTreeNode node) { PTransform transform = node.getTransform(); fullNames.put(transform, node.getFullName()); TransformEvaluator evaluator = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/RecordingPipelineVisitor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/RecordingPipelineVisitor.java index 84df5fdea30f..d64738f5b924 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/RecordingPipelineVisitor.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/RecordingPipelineVisitor.java @@ -30,21 +30,13 @@ * *
<p>
Provided for internal unit tests. */ -public class RecordingPipelineVisitor implements Pipeline.PipelineVisitor { +public class RecordingPipelineVisitor extends Pipeline.PipelineVisitor.Defaults { public final List> transforms = new ArrayList<>(); public final List values = new ArrayList<>(); @Override - public void enterCompositeTransform(TransformTreeNode node) { - } - - @Override - public void leaveCompositeTransform(TransformTreeNode node) { - } - - @Override - public void visitTransform(TransformTreeNode node) { + public void visitPrimitiveTransform(TransformTreeNode node) { transforms.add(node.getTransform()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformTreeNode.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformTreeNode.java index a6efc51825cb..59edd52cdc02 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformTreeNode.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformTreeNode.java @@ -17,7 +17,8 @@ */ package org.apache.beam.sdk.runners; -import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.Pipeline.PipelineVisitor; +import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; @@ -198,7 +199,7 @@ public Collection getExpandedOutputs() { * transform (or child nodes for composite transforms), then the * output values. 
*/ - public void visit(Pipeline.PipelineVisitor visitor, + public void visit(PipelineVisitor visitor, Set visitedValues) { if (!finishedSpecifying) { finishSpecifying(); @@ -212,13 +213,16 @@ public void visit(Pipeline.PipelineVisitor visitor, } if (isCompositeNode()) { - visitor.enterCompositeTransform(this); - for (TransformTreeNode child : parts) { - child.visit(visitor, visitedValues); + PipelineVisitor.CompositeBehavior recurse = visitor.enterCompositeTransform(this); + + if (recurse.equals(CompositeBehavior.ENTER_TRANSFORM)) { + for (TransformTreeNode child : parts) { + child.visit(visitor, visitedValues); + } } visitor.leaveCompositeTransform(this); } else { - visitor.visitTransform(this); + visitor.visitPrimitiveTransform(this); } // Visit outputs. diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java index 7950a9e82020..74cc5e016122 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java @@ -205,7 +205,7 @@ public VisitNodesAnswer(List nodes) { public Object answer(InvocationOnMock invocation) throws Throwable { PipelineVisitor visitor = (PipelineVisitor) invocation.getArguments()[0]; for (TransformTreeNode node : nodes) { - visitor.visitTransform(node); + visitor.visitPrimitiveTransform(node); } return null; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java index e4eb2048be20..aecebd7353d9 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java @@ -40,7 +40,6 @@ import org.apache.beam.sdk.values.PCollection; import 
org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PDone; -import org.apache.beam.sdk.values.PValue; import org.junit.Rule; import org.junit.Test; @@ -128,9 +127,9 @@ public void testCompositeCapture() throws Exception { final EnumSet left = EnumSet.noneOf(TransformsSeen.class); - p.traverseTopologically(new Pipeline.PipelineVisitor() { + p.traverseTopologically(new Pipeline.PipelineVisitor.Defaults() { @Override - public void enterCompositeTransform(TransformTreeNode node) { + public CompositeBehavior enterCompositeTransform(TransformTreeNode node) { PTransform transform = node.getTransform(); if (transform instanceof Sample.SampleAny) { assertTrue(visited.add(TransformsSeen.SAMPLE_ANY)); @@ -142,6 +141,7 @@ public void enterCompositeTransform(TransformTreeNode node) { assertTrue(node.isCompositeNode()); } assertThat(transform, not(instanceOf(Read.Bounded.class))); + return CompositeBehavior.ENTER_TRANSFORM; } @Override @@ -153,7 +153,7 @@ public void leaveCompositeTransform(TransformTreeNode node) { } @Override - public void visitTransform(TransformTreeNode node) { + public void visitPrimitiveTransform(TransformTreeNode node) { PTransform transform = node.getTransform(); // Pick is a composite, should not be visited here. assertThat(transform, not(instanceOf(Sample.SampleAny.class))); @@ -163,10 +163,6 @@ public void visitTransform(TransformTreeNode node) { assertTrue(visited.add(TransformsSeen.READ)); } } - - @Override - public void visitValue(PValue value, TransformTreeNode producer) { - } }); assertTrue(visited.equals(EnumSet.allOf(TransformsSeen.class)));