diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java index a389d7a076c8..bb3acc798a64 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java @@ -23,7 +23,9 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.PipelineOptionsValidator; import org.apache.beam.sdk.runners.PipelineRunner; +import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; @@ -157,6 +159,12 @@ public static FlinkPipelineRunner createForTest(boolean streaming) { @Override public Output apply( PTransform transform, Input input) { + + // In batch mode, expand GroupByKey to GroupByKeyOnly -> GroupAlsoByWindow + if (!options.isStreaming() && transform.getClass().equals(GroupByKey.class)) { /* exact class comparison via getClass().equals, so subclasses of GroupByKey are not rewritten here; every other case falls through to the plain super.apply below */ + return (Output) super.apply(new GroupByKeyViaGroupByKeyOnly((GroupByKey) transform), input); + } + return super.apply(transform, input); } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java index a03352efae15..3b1f51866add 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java @@ -43,7 +43,6 @@ import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; 
-import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; @@ -52,6 +51,12 @@ import org.apache.beam.sdk.transforms.join.CoGroupByKey; import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple; import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly; +import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly; +import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -78,12 +83,14 @@ import org.apache.flink.api.java.operators.MapPartitionOperator; import org.apache.flink.api.java.operators.UnsortedGrouping; import org.apache.flink.core.fs.Path; +import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.lang.reflect.Field; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -114,8 +121,9 @@ public class FlinkBatchTransformTranslators { TRANSLATORS.put(Flatten.FlattenPCollectionList.class, new FlattenPCollectionTranslatorBatch()); - // TODO we're currently ignoring windows here but that has to change in the future - TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch()); + TRANSLATORS.put(GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly.class, new GroupByKeyOnlyTranslatorBatch()); + + TRANSLATORS.put(Window.Bound.class, new WindowBoundTranslatorBatch()); TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiTranslatorBatch()); TRANSLATORS.put(ParDo.Bound.class, new 
ParDoBoundTranslatorBatch()); @@ -303,13 +311,64 @@ public void translateNode(Write.Bound transform, FlinkBatchTranslationContext } } + /** Batch translator for {@link Window.Bound}: runs a DoFn over the input DataSet that re-emits every element into the windows returned by the WindowFn of the transform's output windowing strategy. NOTE(review): generic type arguments look stripped from this patch text (see the bare T further down); confirm against the real source file. */ public static class WindowBoundTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + + /** Wraps the window-assigning DoFn in a FlinkDoFnFunction, applies it through a MapPartitionOperator named after the transform, and registers the result as the DataSet of the transform's output. */ @Override + public void translateNode(Window.Bound transform, FlinkBatchTranslationContext context) { + PValue input = context.getInput(transform); + DataSet inputDataSet = context.getInputDataSet(input); + + @SuppressWarnings("unchecked") + final WindowingStrategy windowingStrategy = + (WindowingStrategy) + context.getOutput(transform).getWindowingStrategy(); /* the strategy is read from the transform's output, not from its input */ + + final WindowFn windowFn = windowingStrategy.getWindowFn(); + FlinkDoFnFunction doFnWrapper = new FlinkDoFnFunction<>(createWindowAssigner(windowFn), context.getPipelineOptions()); + + TypeInformation typeInformation = context.getTypeInfo(context.getOutput(transform)); + MapPartitionOperator outputDataSet = new MapPartitionOperator(inputDataSet, typeInformation, doFnWrapper, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + + /** Builds a DoFn that calls windowFn.assignWindows for each element and outputs that element, with its original timestamp and pane, into the returned windows. */ private static DoFn createWindowAssigner(final WindowFn windowFn) { + return new DoFn() { + + @Override + public void processElement(final ProcessContext c) throws Exception { + Collection windows = windowFn.assignWindows( + windowFn.new AssignContext() { /* adapter exposing the ProcessContext's element, timestamp and current windows to the WindowFn */ + @Override + public T element() { + return c.element(); + } + + @Override + public Instant timestamp() { + return c.timestamp(); + } + + @Override + public Collection windows() { + return c.windowingInternals().windows(); + } + }); + + c.windowingInternals().outputWindowedValue( + c.element(), c.timestamp(), windows, c.pane()); + } + }; + } + } + /** - * Translates a GroupByKey while ignoring window assignments. Current ignores windows. + * Translates a {@link GroupByKeyOnly}, which ignores window assignments. 
*/ - private static class GroupByKeyTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { + private static class GroupByKeyOnlyTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator> { @Override - public void translateNode(GroupByKey transform, FlinkBatchTranslationContext context) { + public void translateNode(GroupByKeyOnly transform, FlinkBatchTranslationContext context) { DataSet> inputDataSet = context.getInputDataSet(context.getInput(transform)); GroupReduceFunction, KV>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>();