diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java
index 7a6c61f8b36d..7f7281e14bd9 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironment.java
@@ -93,14 +93,15 @@ public void translate(FlinkRunner flinkRunner, Pipeline pipeline) {
       throw new RuntimeException(e);
     }
 
-    pipeline.replaceAll(FlinkTransformOverrides.getDefaultOverrides(options.isStreaming()));
-
     PipelineTranslationOptimizer optimizer =
        new PipelineTranslationOptimizer(TranslationMode.BATCH, options);
     optimizer.translate(pipeline);
     TranslationMode translationMode = optimizer.getTranslationMode();
 
+    pipeline.replaceAll(FlinkTransformOverrides.getDefaultOverrides(
+        translationMode == TranslationMode.STREAMING));
+
     FlinkPipelineTranslator translator;
     if (translationMode == TranslationMode.STREAMING) {
       this.flinkStreamEnv = createStreamExecutionEnvironment();
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
index 811c15940c1f..a2923a97cc4f 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
@@ -58,7 +58,6 @@
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn;
@@ -253,7 +252,7 @@ void translateNode(
       if (context.getOutput(transform).isBounded().equals(PCollection.IsBounded.BOUNDED)) {
         boundedTranslator.translateNode(transform, context);
       } else {
-        unboundedTranslator.translateNode((Read.Unbounded) transform, context);
+        unboundedTranslator.translateNode(transform, context);
       }
     }
   }
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/PipelineTranslationOptimizer.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/PipelineTranslationOptimizer.java
index 3acc3eafca13..8877f1a044ac 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/PipelineTranslationOptimizer.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/PipelineTranslationOptimizer.java
@@ -17,9 +17,11 @@
  */
 package org.apache.beam.runners.flink;
 
-import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PValue;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -60,13 +62,21 @@ public void leaveCompositeTransform(TransformHierarchy.Node node) {}
 
   @Override
   public void visitPrimitiveTransform(TransformHierarchy.Node node) {
-    Class<? extends PTransform> transformClass = node.getTransform().getClass();
-    if (transformClass == Read.Unbounded.class) {
+    AppliedPTransform<?, ?, ?> appliedPTransform = node.toAppliedPTransform(getPipeline());
+    if (hasUnboundedOutput(appliedPTransform)) {
+      Class<? extends PTransform> transformClass = node.getTransform().getClass();
       LOG.info("Found {}. Switching to streaming execution.", transformClass);
       translationMode = TranslationMode.STREAMING;
     }
   }
 
+  private boolean hasUnboundedOutput(AppliedPTransform<?, ?, ?> transform) {
+    return transform.getOutputs().values().stream()
+        .filter(value -> value instanceof PCollection)
+        .map(value -> (PCollection<?>) value)
+        .anyMatch(collection -> collection.isBounded() == IsBounded.UNBOUNDED);
+  }
+
   @Override
   public void visitValue(PValue value, TransformHierarchy.Node producer) {}
 }
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java
new file mode 100644
index 000000000000..0e5ce144135e
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineExecutionEnvironmentTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.GenerateSequence;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.joda.time.Duration;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkPipelineExecutionEnvironment}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkPipelineExecutionEnvironmentTest implements Serializable {
+
+  @Test
+  public void shouldRecognizeAndTranslateStreamingPipeline() {
+    FlinkPipelineOptions options = PipelineOptionsFactory.as(FlinkPipelineOptions.class);
+    options.setRunner(TestFlinkRunner.class);
+    options.setFlinkMaster("[auto]");
+
+    FlinkRunner flinkRunner = FlinkRunner.fromOptions(options);
+    FlinkPipelineExecutionEnvironment flinkEnv =
+        new FlinkPipelineExecutionEnvironment(options);
+    Pipeline pipeline = Pipeline.create();
+
+    pipeline
+        .apply(GenerateSequence.from(0).withRate(1, Duration.standardSeconds(1)))
+        .apply(ParDo.of(new DoFn<Long, String>() {
+          @ProcessElement
+          public void processElement(ProcessContext c) throws Exception {
+            c.output(Long.toString(c.element()));
+          }
+        }))
+        .apply(Window.into(FixedWindows.of(Duration.standardHours(1))))
+        .apply(TextIO.write().withNumShards(1).withWindowedWrites().to("/dummy/path"));
+
+    flinkEnv.translate(flinkRunner, pipeline);
+
+    // no exception should be thrown
+  }
+
+}
+