From cdb13cbb74398122f6029992e3f38c3f7108e5d3 Mon Sep 17 00:00:00 2001 From: jto Date: Tue, 23 Jul 2024 11:58:28 +0200 Subject: [PATCH 01/19] [Flink] Set return type of bounded sources --- .../runners/flink/FlinkStreamingTransformTranslators.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index d2171d27a142..d243a88c5af4 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -389,6 +389,9 @@ public void translateNode( new SerializablePipelineOptions(context.getPipelineOptions()), parallelism); + TypeInformation> typeInfo = + context.getTypeInfo(output); + DataStream> source; try { source = @@ -396,7 +399,8 @@ public void translateNode( .getExecutionEnvironment() .fromSource( flinkBoundedSource, WatermarkStrategy.noWatermarks(), fullName, outputTypeInfo) - .uid(fullName); + .uid(fullName) + .returns(typeInfo); } catch (Exception e) { throw new RuntimeException("Error while translating BoundedSource: " + rawSource, e); } From d27e895315110e51327b7813a1a4a8fd525dddf9 Mon Sep 17 00:00:00 2001 From: jto Date: Tue, 23 Jul 2024 17:58:29 +0200 Subject: [PATCH 02/19] [Flink] Use a lazy split enumerator for bounded sources [Flink] fix lazy enumerator package --- .../streaming/io/source/FlinkSource.java | 21 +- .../LazyFlinkSourceSplitEnumerator.java | 180 ++++++++++++++++++ .../bounded/FlinkBoundedSourceReader.java | 6 + 3 files changed, 202 insertions(+), 5 deletions(-) create mode 100644 runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java index 506b651da68f..3e5d68df1df7 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java @@ -118,8 +118,20 @@ public Boundedness getBoundedness() { @Override public SplitEnumerator, Map>>> createEnumerator(SplitEnumeratorContext> enumContext) throws Exception { - return new FlinkSourceSplitEnumerator<>( - enumContext, beamSource, serializablePipelineOptions.get(), numSplits); + + if(boundedness == Boundedness.BOUNDED) { + return new LazyFlinkSourceSplitEnumerator<>( + enumContext, + beamSource, + serializablePipelineOptions.get(), + numSplits); + } else { + return new FlinkSourceSplitEnumerator<>( + enumContext, + beamSource, + serializablePipelineOptions.get(), + numSplits); + } } @Override @@ -128,9 +140,8 @@ public Boundedness getBoundedness() { SplitEnumeratorContext> enumContext, Map>> checkpoint) throws Exception { - FlinkSourceSplitEnumerator enumerator = - new FlinkSourceSplitEnumerator<>( - enumContext, beamSource, serializablePipelineOptions.get(), numSplits); + SplitEnumerator, Map>>> enumerator = + createEnumerator(enumContext); checkpoint.forEach( (subtaskId, splitsForSubtask) -> enumerator.addSplitsBack(splitsForSubtask, subtaskId)); return enumerator; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java new file mode 100644 index 000000000000..fdd14025a95a --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io.source; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import javax.annotation.Nullable; + +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSourceSplit; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSourceSplitEnumerator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.compat.SplitEnumeratorCompat; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.FileBasedSource; +import org.apache.beam.sdk.io.Source; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.flink.api.connector.source.SplitEnumeratorContext; +import org.apache.flink.api.connector.source.SplitsAssignment; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A Flink {@link org.apache.flink.api.connector.source.SplitEnumerator SplitEnumerator} + * implementation that holds a Beam {@link Source} and does the following: + * + *
    + *
  • Split the Beam {@link Source} to desired number of splits. + *
  • Lazily assign the splits to the Flink Source Reader. + *
+ * + * @param The output type of the encapsulated Beam {@link Source}. + */ +public class LazyFlinkSourceSplitEnumerator + implements SplitEnumeratorCompat, Map>>> { + private static final Logger LOG = LoggerFactory.getLogger(LazyFlinkSourceSplitEnumerator.class); + private final SplitEnumeratorContext> context; + private final Source beamSource; + private final PipelineOptions pipelineOptions; + private final int numSplits; + private final List> pendingSplits; + + public LazyFlinkSourceSplitEnumerator( + SplitEnumeratorContext> context, + Source beamSource, + PipelineOptions pipelineOptions, + int numSplits) { + this.context = context; + this.beamSource = beamSource; + this.pipelineOptions = pipelineOptions; + this.numSplits = numSplits; + this.pendingSplits = new ArrayList<>(numSplits); + } + + @Override + public void start() { + try { + LOG.info("Starting source {}", beamSource); + List> beamSplitSourceList = splitBeamSource(); + int i = 0; + for (Source beamSplitSource : beamSplitSourceList) { + pendingSplits.add(new FlinkSourceSplit<>(i, beamSplitSource)); + i++; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void handleSplitRequest(int subtask, @Nullable String hostname) { + if (!context.registeredReaders().containsKey(subtask)) { + // reader failed between sending the request and now. skip this request. + return; + } + + if (LOG.isInfoEnabled()) { + final String hostInfo = + hostname == null ? "(no host locality info)" : "(on host '" + hostname + "')"; + LOG.info("Subtask {} {} is requesting a file source split", subtask, hostInfo); + } + + if (!pendingSplits.isEmpty()) { + final FlinkSourceSplit split = pendingSplits.remove(pendingSplits.size() - 1); + context.assignSplit(split, subtask); + LOG.info("Assigned split to subtask {} : {}", subtask, split); + } else { + context.signalNoMoreSplits(subtask); + LOG.info("No more splits available for subtask {}", subtask); + } + } + + @Override + public void addSplitsBack(List> splits, int subtaskId) { + LOG.info("Adding splits {} back from subtask {}", splits, subtaskId); + pendingSplits.addAll(splits); + } + + @Override + public void addReader(int subtaskId) { + // this source is purely lazy-pull-based, nothing to do upon registration + } + + @Override + public Map>> snapshotState(long checkpointId) throws Exception { + LOG.info("Taking snapshot for checkpoint {}", checkpointId); + return snapshotState(); + } + + @Override + public Map>> snapshotState() throws Exception { + // For type compatibility reasons, we return a Map but we do not actually care about the key + Map>> state = new HashMap<>(1); + state.put(1, pendingSplits); + return state; + } + + @Override + public void close() throws IOException { + // NoOp + } + + private long getDesiredSizeBytes(int numSplits, BoundedSource boundedSource) throws Exception { + long totalSize = boundedSource.getEstimatedSizeBytes(pipelineOptions); + long defaultSplitSize = totalSize / numSplits; + long maxSplitSize = 0; + if (pipelineOptions != null) { + maxSplitSize = pipelineOptions.as(FlinkPipelineOptions.class).getFileInputSplitMaxSizeMB(); + } + if (beamSource instanceof FileBasedSource && maxSplitSize > 0) { + // Most of the time parallelism is < number of files in source. + // Each file becomes a unique split which commonly create skew. + // This limits the size of splits to reduce skew. + return Math.min(defaultSplitSize, maxSplitSize * 1024 * 1024); + } else { + return defaultSplitSize; + } + } + + // -------------- Private helper methods ---------------------- + private List> splitBeamSource() throws Exception { + if (beamSource instanceof BoundedSource) { + BoundedSource boundedSource = (BoundedSource) beamSource; + long desiredSizeBytes = getDesiredSizeBytes(numSplits, boundedSource); + List> splits = + ((BoundedSource) beamSource).split(desiredSizeBytes, pipelineOptions); + LOG.info("Split bounded source {} in {} splits", beamSource, splits.size()); + return splits; + } else if (beamSource instanceof UnboundedSource) { + List> splits = + ((UnboundedSource) beamSource).split(numSplits, pipelineOptions); + LOG.info("Split source {} to {} splits", beamSource, splits); + return splits; + } else { + throw new IllegalStateException("Unknown source type " + beamSource.getClass()); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java index e4bd4496ae90..d87d84d93dc2 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java @@ -100,6 +100,11 @@ protected FlinkBoundedSourceReader( @Override public InputStatus pollNext(ReaderOutput> output) throws Exception { checkExceptionAndMaybeThrow(); + + if(currentReader == null && currentSplitId == -1) { + context.sendSplitRequest(); + } + if (currentReader == null && !moveToNextNonEmptyReader()) { // Nothing to read for now. if (noMoreSplits()) { @@ -137,6 +142,7 @@ public InputStatus pollNext(ReaderOutput> output) throws Except LOG.debug("Finished reading from {}", currentSplitId); currentReader = null; currentSplitId = -1; + context.sendSplitRequest(); } // Always return MORE_AVAILABLE here regardless of the availability of next record. If there // is no more From a3267cb3b88b8e3da754e4f7b20885f23fca2c22 Mon Sep 17 00:00:00 2001 From: jto Date: Mon, 19 Aug 2024 10:01:41 +0200 Subject: [PATCH 03/19] [Flink] Default to maxParallelism = parallelism in batch --- .../beam/runners/flink/FlinkExecutionEnvironments.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java index 7c1bc87ced03..67a091b46ff4 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java @@ -237,6 +237,13 @@ public static StreamExecutionEnvironment createStreamExecutionEnvironment( flinkStreamEnv.setParallelism(parallelism); if (options.getMaxParallelism() > 0) { flinkStreamEnv.setMaxParallelism(options.getMaxParallelism()); + } else if(!options.isStreaming()) { + // In Flink maxParallelism defines the number of keyGroups. + // (see https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L76) + // The default value (parallelism * 1.5) + // (see https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L137-L147) + // create a lot of skew so we force maxParallelism = parallelism in Batch mode. + flinkStreamEnv.setMaxParallelism(parallelism); } // set parallelism in the options (required by some execution code) options.setParallelism(parallelism); From d531300137a4d0a5bc28c866fc9352fbb98080a8 Mon Sep 17 00:00:00 2001 From: jto Date: Tue, 20 Aug 2024 11:43:48 +0200 Subject: [PATCH 04/19] [Flink] Avoid re-serializing trigger on every element --- .../core/GroupAlsoByWindowViaWindowSetNewDoFn.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java index 0759487565b0..853a182b2ca0 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java @@ -18,6 +18,8 @@ package org.apache.beam.runners.core; import java.util.Collection; + +import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine; import org.apache.beam.runners.core.triggers.TriggerStateMachines; import org.apache.beam.sdk.transforms.DoFn; @@ -41,6 +43,7 @@ public class GroupAlsoByWindowViaWindowSetNewDoFn< extends DoFn> { private static final long serialVersionUID = 1L; + private final RunnerApi.Trigger triggerProto; public static DoFn, KV> create( @@ -86,6 +89,7 @@ public GroupAlsoByWindowViaWindowSetNewDoFn( this.windowingStrategy = noWildcard; this.reduceFn = reduceFn; this.stateInternalsFactory = stateInternalsFactory; + this.triggerProto = TriggerTranslation.toProto(windowingStrategy.getTrigger()); } private OutputWindowedValue> outputWindowedValue() { @@ -123,9 +127,7 @@ public void processElement(ProcessContext c) throws Exception { new ReduceFnRunner<>( key, windowingStrategy, - ExecutableTriggerStateMachine.create( - TriggerStateMachines.stateMachineForTrigger( - TriggerTranslation.toProto(windowingStrategy.getTrigger()))), + ExecutableTriggerStateMachine.create(TriggerStateMachines.stateMachineForTrigger(triggerProto)), stateInternals, timerInternals, outputWindowedValue(), From 5e55454ce0211b4727a900af634d6bf53911359c Mon Sep 17 00:00:00 2001 From: jto Date: Tue, 20 Aug 2024 11:23:50 +0200 Subject: [PATCH 05/19] [Flink] Avoid re-evaluating options every time a new state is stored --- .../types/CoderTypeSerializer.java | 19 ++--- .../types/CoderTypeSerializer.java | 19 ++--- .../streaming/state/FlinkStateInternals.java | 75 ++++++++++--------- 3 files changed, 50 insertions(+), 63 deletions(-) diff --git a/runners/flink/1.14/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.14/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java index 956aad428d8b..6c21ea8edc00 100644 --- a/runners/flink/1.14/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ b/runners/flink/1.14/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java @@ -50,23 +50,16 @@ public class CoderTypeSerializer extends TypeSerializer { private final Coder coder; - /** - * {@link SerializablePipelineOptions} deserialization will cause {@link - * org.apache.beam.sdk.io.FileSystems} registration needed for {@link - * org.apache.beam.sdk.transforms.Reshuffle} translation. - */ - private final SerializablePipelineOptions pipelineOptions; - private final boolean fasterCopy; public CoderTypeSerializer(Coder coder, SerializablePipelineOptions pipelineOptions) { + this(coder, Preconditions.checkNotNull(pipelineOptions).get().as(FlinkPipelineOptions.class).getFasterCopy()); + } + + public CoderTypeSerializer(Coder coder, boolean fasterCopy) { Preconditions.checkNotNull(coder); - Preconditions.checkNotNull(pipelineOptions); this.coder = coder; - this.pipelineOptions = pipelineOptions; - - FlinkPipelineOptions options = pipelineOptions.get().as(FlinkPipelineOptions.class); - this.fasterCopy = options.getFasterCopy(); + this.fasterCopy = fasterCopy; } @Override @@ -76,7 +69,7 @@ public boolean isImmutableType() { @Override public CoderTypeSerializer duplicate() { - return new CoderTypeSerializer<>(coder, pipelineOptions); + return new CoderTypeSerializer<>(coder, fasterCopy); } @Override diff --git a/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java index 0f87271a9779..911dd3185adf 100644 --- a/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ b/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java @@ -47,23 +47,16 @@ public class CoderTypeSerializer extends TypeSerializer { private final Coder coder; - /** - * {@link SerializablePipelineOptions} deserialization will cause {@link - * org.apache.beam.sdk.io.FileSystems} registration needed for {@link - * org.apache.beam.sdk.transforms.Reshuffle} translation. - */ - private final SerializablePipelineOptions pipelineOptions; - private final boolean fasterCopy; public CoderTypeSerializer(Coder coder, SerializablePipelineOptions pipelineOptions) { + this(coder, Preconditions.checkNotNull(pipelineOptions).get().as(FlinkPipelineOptions.class).getFasterCopy()); + } + + public CoderTypeSerializer(Coder coder, boolean fasterCopy) { Preconditions.checkNotNull(coder); - Preconditions.checkNotNull(pipelineOptions); this.coder = coder; - this.pipelineOptions = pipelineOptions; - - FlinkPipelineOptions options = pipelineOptions.get().as(FlinkPipelineOptions.class); - this.fasterCopy = options.getFasterCopy(); + this.fasterCopy = fasterCopy; } @Override @@ -73,7 +66,7 @@ public boolean isImmutableType() { @Override public CoderTypeSerializer duplicate() { - return new CoderTypeSerializer<>(coder, pipelineOptions); + return new CoderTypeSerializer<>(coder, fasterCopy); } @Override diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index 205270c22332..bb662669179d 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -33,6 +33,7 @@ import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkKeyUtils; import org.apache.beam.sdk.coders.Coder; @@ -162,7 +163,7 @@ public String toString() { // State to persist combined watermark holds for all keys of this partition private final MapStateDescriptor watermarkHoldStateDescriptor; - private final SerializablePipelineOptions pipelineOptions; + private final boolean fasterCopy; public FlinkStateInternals( KeyedStateBackend flinkStateBackend, @@ -171,13 +172,13 @@ public FlinkStateInternals( throws Exception { this.flinkStateBackend = Objects.requireNonNull(flinkStateBackend); this.keyCoder = Objects.requireNonNull(keyCoder); + this.fasterCopy = pipelineOptions.get().as(FlinkPipelineOptions.class).getFasterCopy(); + watermarkHoldStateDescriptor = new MapStateDescriptor<>( "watermark-holds", StringSerializer.INSTANCE, - new CoderTypeSerializer<>(InstantCoder.of(), pipelineOptions)); - this.pipelineOptions = pipelineOptions; - + new CoderTypeSerializer<>(InstantCoder.of(), fasterCopy)); restoreWatermarkHoldsView(); } @@ -241,7 +242,7 @@ private FlinkStateBinder(StateNamespace namespace, StateContext stateContext) public ValueState bindValue( String id, StateSpec> spec, Coder coder) { FlinkValueState valueState = - new FlinkValueState<>(flinkStateBackend, id, namespace, coder, pipelineOptions); + new FlinkValueState<>(flinkStateBackend, id, namespace, coder, fasterCopy); collectGlobalWindowStateDescriptor( valueState.flinkStateDescriptor, valueState.namespace.stringKey(), @@ -252,7 +253,7 @@ public ValueState bindValue( @Override public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { FlinkBagState bagState = - new FlinkBagState<>(flinkStateBackend, id, namespace, elemCoder, pipelineOptions); + new FlinkBagState<>(flinkStateBackend, id, namespace, elemCoder, fasterCopy); collectGlobalWindowStateDescriptor( bagState.flinkStateDescriptor, bagState.namespace.stringKey(), StringSerializer.INSTANCE); return bagState; @@ -261,7 +262,7 @@ public BagState bindBag(String id, StateSpec> spec, Coder< @Override public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { FlinkSetState setState = - new FlinkSetState<>(flinkStateBackend, id, namespace, elemCoder, pipelineOptions); + new FlinkSetState<>(flinkStateBackend, id, namespace, elemCoder, fasterCopy); collectGlobalWindowStateDescriptor( setState.flinkStateDescriptor, setState.namespace.stringKey(), StringSerializer.INSTANCE); return setState; @@ -275,7 +276,7 @@ public MapState bindMap( Coder mapValueCoder) { FlinkMapState mapState = new FlinkMapState<>( - flinkStateBackend, id, namespace, mapKeyCoder, mapValueCoder, pipelineOptions); + flinkStateBackend, id, namespace, mapKeyCoder, mapValueCoder, fasterCopy); collectGlobalWindowStateDescriptor( mapState.flinkStateDescriptor, mapState.namespace.stringKey(), StringSerializer.INSTANCE); return mapState; @@ -285,7 +286,7 @@ public MapState bindMap( public OrderedListState bindOrderedList( String id, StateSpec> spec, Coder elemCoder) { FlinkOrderedListState flinkOrderedListState = - new FlinkOrderedListState<>(flinkStateBackend, id, namespace, elemCoder, pipelineOptions); + new FlinkOrderedListState<>(flinkStateBackend, id, namespace, elemCoder, fasterCopy); collectGlobalWindowStateDescriptor( flinkOrderedListState.flinkStateDescriptor, flinkOrderedListState.namespace.stringKey(), @@ -311,7 +312,7 @@ public CombiningState bindCom Combine.CombineFn combineFn) { FlinkCombiningState combiningState = new FlinkCombiningState<>( - flinkStateBackend, id, combineFn, namespace, accumCoder, pipelineOptions); + flinkStateBackend, id, combineFn, namespace, accumCoder, fasterCopy); collectGlobalWindowStateDescriptor( combiningState.flinkStateDescriptor, combiningState.namespace.stringKey(), @@ -334,7 +335,7 @@ CombiningState bindCombiningWithContext( namespace, accumCoder, CombineContextFactory.createFromStateContext(stateContext), - pipelineOptions); + fasterCopy); collectGlobalWindowStateDescriptor( combiningStateWithContext.flinkStateDescriptor, combiningStateWithContext.namespace.stringKey(), @@ -380,14 +381,14 @@ private static class FlinkValueState implements ValueState { String stateId, StateNamespace namespace, Coder coder, - SerializablePipelineOptions pipelineOptions) { + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; flinkStateDescriptor = - new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, pipelineOptions)); + new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, fasterCopy)); } @Override @@ -463,12 +464,12 @@ private static class FlinkOrderedListState implements OrderedListState { String stateId, StateNamespace namespace, Coder coder, - SerializablePipelineOptions pipelineOptions) { + boolean fasterCopy) { this.namespace = namespace; this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = new ListStateDescriptor<>( - stateId, new CoderTypeSerializer<>(TimestampedValueCoder.of(coder), pipelineOptions)); + stateId, new CoderTypeSerializer<>(TimestampedValueCoder.of(coder), fasterCopy)); } @Override @@ -586,14 +587,14 @@ private static class FlinkBagState implements BagState { String stateId, StateNamespace namespace, Coder coder, - SerializablePipelineOptions pipelineOptions) { + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; this.storesVoidValues = coder instanceof VoidCoder; this.flinkStateDescriptor = - new ListStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, pipelineOptions)); + new ListStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, fasterCopy)); } @Override @@ -726,7 +727,7 @@ private static class FlinkCombiningState Combine.CombineFn combineFn, StateNamespace namespace, Coder accumCoder, - SerializablePipelineOptions pipelineOptions) { + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; @@ -735,7 +736,7 @@ private static class FlinkCombiningState flinkStateDescriptor = new ValueStateDescriptor<>( - stateId, new CoderTypeSerializer<>(accumCoder, pipelineOptions)); + stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); } @Override @@ -891,7 +892,7 @@ private static class FlinkCombiningStateWithContext StateNamespace namespace, Coder accumCoder, CombineWithContext.Context context, - SerializablePipelineOptions pipelineOptions) { + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; @@ -901,7 +902,7 @@ private static class FlinkCombiningStateWithContext flinkStateDescriptor = new ValueStateDescriptor<>( - stateId, new CoderTypeSerializer<>(accumCoder, pipelineOptions)); + stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); } @Override @@ -1178,15 +1179,15 @@ private static class FlinkMapState implements MapState mapKeyCoder, Coder mapValueCoder, - SerializablePipelineOptions pipelineOptions) { + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = new MapStateDescriptor<>( stateId, - new CoderTypeSerializer<>(mapKeyCoder, pipelineOptions), - new CoderTypeSerializer<>(mapValueCoder, pipelineOptions)); + new CoderTypeSerializer<>(mapKeyCoder, fasterCopy), + new CoderTypeSerializer<>(mapValueCoder, fasterCopy)); } @Override @@ -1402,14 +1403,14 @@ private static class FlinkSetState implements SetState { String stateId, StateNamespace namespace, Coder coder, - SerializablePipelineOptions pipelineOptions) { + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = new MapStateDescriptor<>( stateId, - new CoderTypeSerializer<>(coder, pipelineOptions), + new CoderTypeSerializer<>(coder, fasterCopy), BooleanSerializer.INSTANCE); } @@ -1571,12 +1572,12 @@ private void restoreWatermarkHoldsView() throws Exception { public static class EarlyBinder implements StateBinder { private final KeyedStateBackend keyedStateBackend; - private final SerializablePipelineOptions pipelineOptions; + private final Boolean fasterCopy; public EarlyBinder( KeyedStateBackend keyedStateBackend, SerializablePipelineOptions pipelineOptions) { this.keyedStateBackend = keyedStateBackend; - this.pipelineOptions = pipelineOptions; + this.fasterCopy = pipelineOptions.get().as(FlinkPipelineOptions.class).getFasterCopy(); } @Override @@ -1584,7 +1585,7 @@ public ValueState bindValue(String id, StateSpec> spec, Cod try { keyedStateBackend.getOrCreateKeyedState( StringSerializer.INSTANCE, - new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(coder, pipelineOptions))); + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(coder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1597,7 +1598,7 @@ public BagState bindBag(String id, StateSpec> spec, Coder try { keyedStateBackend.getOrCreateKeyedState( StringSerializer.INSTANCE, - new ListStateDescriptor<>(id, new CoderTypeSerializer<>(elemCoder, pipelineOptions))); + new ListStateDescriptor<>(id, new CoderTypeSerializer<>(elemCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1612,7 +1613,7 @@ public SetState bindSet(String id, StateSpec> spec, Coder StringSerializer.INSTANCE, new MapStateDescriptor<>( id, - new CoderTypeSerializer<>(elemCoder, pipelineOptions), + new CoderTypeSerializer<>(elemCoder, fasterCopy), BooleanSerializer.INSTANCE)); } catch (Exception e) { throw new RuntimeException(e); @@ -1631,8 +1632,8 @@ public org.apache.beam.sdk.state.MapState bindMap( StringSerializer.INSTANCE, new MapStateDescriptor<>( id, - new CoderTypeSerializer<>(mapKeyCoder, pipelineOptions), - new CoderTypeSerializer<>(mapValueCoder, pipelineOptions))); + new CoderTypeSerializer<>(mapKeyCoder, fasterCopy), + new CoderTypeSerializer<>(mapValueCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1647,7 +1648,7 @@ public OrderedListState bindOrderedList( StringSerializer.INSTANCE, new ListStateDescriptor<>( id, - new CoderTypeSerializer<>(TimestampedValueCoder.of(elemCoder), pipelineOptions))); + new CoderTypeSerializer<>(TimestampedValueCoder.of(elemCoder), fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1674,7 +1675,7 @@ public CombiningState bindCom try { keyedStateBackend.getOrCreateKeyedState( StringSerializer.INSTANCE, - new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, pipelineOptions))); + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1691,7 +1692,7 @@ CombiningState bindCombiningWithContext( try { keyedStateBackend.getOrCreateKeyedState( StringSerializer.INSTANCE, - new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, pipelineOptions))); + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1707,7 +1708,7 @@ public WatermarkHoldState bindWatermark( new MapStateDescriptor<>( "watermark-holds", StringSerializer.INSTANCE, - new CoderTypeSerializer<>(InstantCoder.of(), pipelineOptions))); + new CoderTypeSerializer<>(InstantCoder.of(), fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } From 2f3567081bf2d6a3852cdeabecf85002fbfa3ece Mon Sep 17 00:00:00 2001 From: jto Date: Wed, 21 Aug 2024 10:59:17 +0200 Subject: [PATCH 06/19] [Flink] Only serialize states namespace keys if necessary --- .../wrappers/streaming/DoFnOperator.java | 4 +- .../ExecutableStageDoFnOperator.java | 7 +- .../streaming/state/FlinkStateInternals.java | 285 ++++++++++++++---- .../streaming/FlinkStateInternalsTest.java | 4 + 4 files changed, 229 insertions(+), 71 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index 63f5ede00242..055831a8103b 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -462,7 +462,7 @@ public void initializeState(StateInitializationContext context) throws Exception if (keyCoder != null) { keyedStateInternals = new FlinkStateInternals<>( - (KeyedStateBackend) getKeyedStateBackend(), keyCoder, serializedOptions); + (KeyedStateBackend) getKeyedStateBackend(), keyCoder, windowingStrategy.getWindowFn().windowCoder(), serializedOptions); if (timerService == null) { timerService = @@ -590,7 +590,7 @@ private void earlyBindStateIfNeeded() throws IllegalArgumentException, IllegalAc if (doFn != null) { DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); FlinkStateInternals.EarlyBinder earlyBinder = - new FlinkStateInternals.EarlyBinder(getKeyedStateBackend(), serializedOptions); + new FlinkStateInternals.EarlyBinder(getKeyedStateBackend(), serializedOptions, windowingStrategy.getWindowFn().windowCoder()); for (DoFnSignature.StateDeclaration value : signature.stateDeclarations().values()) { StateSpec spec = (StateSpec) signature.stateDeclarations().get(value.id()).field().get(doFn); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java index 456f75b0ee67..7ec37cbe6dd3 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java @@ -247,7 +247,7 @@ protected Lock getLockToAcquireForStateAccessDuringBundles() { public void open() throws Exception { executableStage = ExecutableStage.fromPayload(payload); hasSdfProcessFn = hasSDF(executableStage); - initializeUserState(executableStage, getKeyedStateBackend(), pipelineOptions); + initializeUserState(executableStage, getKeyedStateBackend(), pipelineOptions, windowCoder); // TODO: Wire this into the distributed cache and make it pluggable. // TODO: Do we really want this layer of indirection when accessing the stage bundle factory? // It's a little strange because this operator is responsible for the lifetime of the stage @@ -1280,14 +1280,15 @@ void cleanupState(StateInternals stateInternals, Consumer keyContext private static void initializeUserState( ExecutableStage executableStage, @Nullable KeyedStateBackend keyedStateBackend, - SerializablePipelineOptions pipelineOptions) { + SerializablePipelineOptions pipelineOptions, + Coder windowCoder) { executableStage .getUserStates() .forEach( ref -> { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + new FlinkStateInternals.FlinkStateNamespaceKeySerializer(windowCoder), new ListStateDescriptor<>( ref.localName(), new CoderTypeSerializer<>(ByteStringCoder.of(), pipelineOptions))); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index bb662669179d..8102582c4817 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming.state; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.Collections; import java.util.HashSet; @@ -28,6 +29,8 @@ import java.util.function.Function; import java.util.stream.Stream; import javax.annotation.Nonnull; + +import com.esotericsoftware.kryo.serializers.DefaultSerializers; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaces; @@ -56,6 +59,7 @@ import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.CombineContextFactory; @@ -75,8 +79,13 @@ import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.JavaSerializer; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; @@ -103,6 +112,7 @@ public class FlinkStateInternals implements StateInternals { private final KeyedStateBackend flinkStateBackend; private final Coder keyCoder; + FlinkStateNamespaceKeySerializer namespaceKeySerializer; private static class StateAndNamespaceDescriptor { static StateAndNamespaceDescriptor of( @@ -168,11 +178,13 @@ public String toString() { public FlinkStateInternals( KeyedStateBackend flinkStateBackend, Coder keyCoder, + Coder windowCoder, SerializablePipelineOptions pipelineOptions) throws Exception { this.flinkStateBackend = Objects.requireNonNull(flinkStateBackend); this.keyCoder = Objects.requireNonNull(keyCoder); this.fasterCopy = pipelineOptions.get().as(FlinkPipelineOptions.class).getFasterCopy(); + this.namespaceKeySerializer = new FlinkStateNamespaceKeySerializer(windowCoder); watermarkHoldStateDescriptor = new MapStateDescriptor<>( @@ -242,29 +254,28 @@ private FlinkStateBinder(StateNamespace namespace, StateContext stateContext) public ValueState bindValue( String id, StateSpec> spec, Coder coder) { FlinkValueState valueState = - new FlinkValueState<>(flinkStateBackend, id, namespace, coder, fasterCopy); + new FlinkValueState<>(flinkStateBackend, id, namespace, coder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( valueState.flinkStateDescriptor, - valueState.namespace.stringKey(), - StringSerializer.INSTANCE); + valueState.namespace, namespaceKeySerializer); return valueState; } @Override public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { FlinkBagState bagState = - new FlinkBagState<>(flinkStateBackend, id, namespace, elemCoder, fasterCopy); + new FlinkBagState<>(flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( - bagState.flinkStateDescriptor, bagState.namespace.stringKey(), StringSerializer.INSTANCE); + bagState.flinkStateDescriptor, bagState.namespace, namespaceKeySerializer); return bagState; } @Override public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { FlinkSetState setState = - new FlinkSetState<>(flinkStateBackend, id, namespace, elemCoder, fasterCopy); + new FlinkSetState<>(flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( - setState.flinkStateDescriptor, setState.namespace.stringKey(), StringSerializer.INSTANCE); + setState.flinkStateDescriptor, setState.namespace, namespaceKeySerializer); return setState; } @@ -276,9 +287,9 @@ public MapState bindMap( Coder mapValueCoder) { FlinkMapState mapState = new FlinkMapState<>( - flinkStateBackend, id, namespace, mapKeyCoder, mapValueCoder, fasterCopy); + flinkStateBackend, id, namespace, mapKeyCoder, mapValueCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( - mapState.flinkStateDescriptor, mapState.namespace.stringKey(), StringSerializer.INSTANCE); + mapState.flinkStateDescriptor, mapState.namespace, namespaceKeySerializer); return mapState; } @@ -286,11 +297,11 @@ public MapState bindMap( public OrderedListState bindOrderedList( String id, StateSpec> spec, Coder elemCoder) { FlinkOrderedListState flinkOrderedListState = - new FlinkOrderedListState<>(flinkStateBackend, id, namespace, elemCoder, fasterCopy); + new FlinkOrderedListState<>(flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( flinkOrderedListState.flinkStateDescriptor, - flinkOrderedListState.namespace.stringKey(), - StringSerializer.INSTANCE); + flinkOrderedListState.namespace, + namespaceKeySerializer); return flinkOrderedListState; } @@ -312,11 +323,11 @@ public CombiningState bindCom Combine.CombineFn combineFn) { FlinkCombiningState combiningState = new FlinkCombiningState<>( - flinkStateBackend, id, combineFn, namespace, accumCoder, fasterCopy); + flinkStateBackend, id, combineFn, namespace, accumCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( combiningState.flinkStateDescriptor, - combiningState.namespace.stringKey(), - StringSerializer.INSTANCE); + combiningState.namespace, + namespaceKeySerializer); return combiningState; } @@ -334,12 +345,13 @@ CombiningState bindCombiningWithContext( combineFn, namespace, accumCoder, + namespaceKeySerializer, CombineContextFactory.createFromStateContext(stateContext), fasterCopy); collectGlobalWindowStateDescriptor( combiningStateWithContext.flinkStateDescriptor, - combiningStateWithContext.namespace.stringKey(), - StringSerializer.INSTANCE); + combiningStateWithContext.namespace, + namespaceKeySerializer); return combiningStateWithContext; } @@ -369,23 +381,146 @@ private void collectGlobalWindowStateDescriptor( } } + public static class FlinkStateNamespaceKeySerializer extends TypeSerializer { + + public Coder getCoder() { + return coder; + } + + private final Coder coder; + + public FlinkStateNamespaceKeySerializer(Coder coder) { + this.coder = coder; + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return this; + } + + @Override + public StateNamespace createInstance() { + return null; + } + + @Override + public StateNamespace copy(StateNamespace from) { + return from; + } + + @Override + public StateNamespace copy(StateNamespace from, StateNamespace reuse) { + return from; + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(StateNamespace record, DataOutputView target) throws IOException { + StringSerializer.INSTANCE.serialize(record.stringKey(), target); + } + + @Override + public StateNamespace deserialize(DataInputView source) throws IOException { + return StateNamespaces.fromString(StringSerializer.INSTANCE.deserialize(source), coder); + } + + @Override + public StateNamespace deserialize(StateNamespace reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + throw new UnsupportedOperationException("copy is not supported for FlinkStateNamespace key"); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof FlinkStateNamespaceKeySerializer; + } + + @Override + public int hashCode() { + return Objects.hashCode(getClass()); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new FlinkStateNameSpaceSerializerSnapshot(this); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public final static class FlinkStateNameSpaceSerializerSnapshot implements TypeSerializerSnapshot { + + @Nullable + private Coder windowCoder; + + public FlinkStateNameSpaceSerializerSnapshot(){ + + } + + FlinkStateNameSpaceSerializerSnapshot(FlinkStateNamespaceKeySerializer ser) { + this.windowCoder = ser.getCoder(); + } + + @Override + public int getCurrentVersion() { + return 0; + } + + @Override + public void writeSnapshot(DataOutputView out) throws IOException { + new JavaSerializer>().serialize(windowCoder, out); + } + + @Override + public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) throws IOException { + this.windowCoder = new JavaSerializer>().deserialize(in); + } + + @Override + public TypeSerializer restoreSerializer() { + return new FlinkStateNamespaceKeySerializer(windowCoder); + } + + @Override + public TypeSerializerSchemaCompatibility resolveSchemaCompatibility(TypeSerializer newSerializer) { + return TypeSerializerSchemaCompatibility.compatibleAsIs(); + } + } + } + private static class FlinkValueState implements ValueState { private final StateNamespace namespace; private final String stateId; private final ValueStateDescriptor flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkValueState( KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, Coder coder, + FlinkStateNamespaceKeySerializer namespaceSerializer, boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; + this.namespaceSerializer = namespaceSerializer; + flinkStateDescriptor = new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, fasterCopy)); @@ -396,7 +531,7 @@ public void write(T input) { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .update(input); } catch (Exception e) { throw new RuntimeException("Error updating state.", e); @@ -413,7 +548,7 @@ public T read() { try { return flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .value(); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -424,8 +559,7 @@ public T read() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -458,18 +592,21 @@ private static class FlinkOrderedListState implements OrderedListState { private final StateNamespace namespace; private final ListStateDescriptor> flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkOrderedListState( KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, Coder coder, + FlinkStateNamespaceKeySerializer namespaceSerializer, boolean fasterCopy) { this.namespace = namespace; this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = new ListStateDescriptor<>( stateId, new CoderTypeSerializer<>(TimestampedValueCoder.of(coder), fasterCopy)); + this.namespaceSerializer = namespaceSerializer; } @Override @@ -484,7 +621,7 @@ public void clearRange(Instant minTimestamp, Instant limitTimestamp) { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); partitionedState.update(Lists.newArrayList(sortedMap.values())); } catch (Exception e) { throw new RuntimeException("Error adding to bag state.", e); @@ -501,7 +638,7 @@ public void add(TimestampedValue value) { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); partitionedState.add(value); } catch (Exception e) { throw new RuntimeException("Error adding to bag state.", e); @@ -517,7 +654,7 @@ public Boolean read() { Iterable> result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .get(); return result == null; } catch (Exception e) { @@ -543,7 +680,7 @@ private SortedMap> readAsMap() { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); listValues = MoreObjects.firstNonNull(partitionedState.get(), Collections.emptyList()); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -566,7 +703,7 @@ public void clear() { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -581,12 +718,14 @@ private static class FlinkBagState implements BagState { private final ListStateDescriptor flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; private final boolean storesVoidValues; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkBagState( KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, Coder coder, + FlinkStateNamespaceKeySerializer namespaceSerializer, boolean fasterCopy) { this.namespace = namespace; @@ -595,6 +734,7 @@ private static class FlinkBagState implements BagState { this.storesVoidValues = coder instanceof VoidCoder; this.flinkStateDescriptor = new ListStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, fasterCopy)); + this.namespaceSerializer = namespaceSerializer; } @Override @@ -602,7 +742,7 @@ public void add(T input) { try { ListState partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); if (storesVoidValues) { Preconditions.checkState(input == null, "Expected to a null value but was: %s", input); // Flink does not allow storing null values @@ -626,7 +766,7 @@ public Iterable read() { try { ListState partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); Iterable result = partitionedState.get(); if (storesVoidValues) { return () -> { @@ -663,7 +803,7 @@ public Boolean read() { Iterable result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .get(); return result == null; } catch (Exception e) { @@ -683,7 +823,7 @@ public void clear() { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -720,6 +860,7 @@ private static class FlinkCombiningState private final Combine.CombineFn combineFn; private final ValueStateDescriptor flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkCombiningState( KeyedStateBackend flinkStateBackend, @@ -727,12 +868,14 @@ private static class FlinkCombiningState Combine.CombineFn combineFn, StateNamespace namespace, Coder accumCoder, + FlinkStateNamespaceKeySerializer namespaceSerializer, boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.combineFn = combineFn; this.flinkStateBackend = flinkStateBackend; + this.namespaceSerializer = namespaceSerializer; flinkStateDescriptor = new ValueStateDescriptor<>( @@ -749,7 +892,7 @@ public void add(InputT value) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -767,7 +910,7 @@ public void addAccum(AccumT accum) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -787,7 +930,7 @@ public AccumT getAccum() { AccumT accum = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .value(); return accum != null ? accum : combineFn.createAccumulator(); } catch (Exception e) { @@ -805,7 +948,7 @@ public OutputT read() { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT accum = state.value(); if (accum != null) { @@ -826,7 +969,7 @@ public Boolean read() { try { return flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .value() == null; } catch (Exception e) { @@ -846,7 +989,7 @@ public void clear() { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -884,6 +1027,7 @@ private static class FlinkCombiningStateWithContext private final ValueStateDescriptor flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; private final CombineWithContext.Context context; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkCombiningStateWithContext( KeyedStateBackend flinkStateBackend, @@ -891,6 +1035,7 @@ private static class FlinkCombiningStateWithContext CombineWithContext.CombineFnWithContext combineFn, StateNamespace namespace, Coder accumCoder, + FlinkStateNamespaceKeySerializer namespaceSerializer, CombineWithContext.Context context, boolean fasterCopy) { @@ -899,6 +1044,7 @@ private static class FlinkCombiningStateWithContext this.combineFn = combineFn; this.flinkStateBackend = flinkStateBackend; this.context = context; + this.namespaceSerializer = namespaceSerializer; flinkStateDescriptor = new ValueStateDescriptor<>( @@ -915,7 +1061,7 @@ public void add(InputT value) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -933,7 +1079,7 @@ public void addAccum(AccumT accum) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -953,7 +1099,7 @@ public AccumT getAccum() { AccumT accum = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .value(); return accum != null ? accum : combineFn.createAccumulator(context); } catch (Exception e) { @@ -971,7 +1117,7 @@ public OutputT read() { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT accum = state.value(); if (accum != null) { @@ -992,7 +1138,7 @@ public Boolean read() { try { return flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .value() == null; } catch (Exception e) { @@ -1012,7 +1158,7 @@ public void clear() { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1172,6 +1318,7 @@ private static class FlinkMapState implements MapState flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkMapState( KeyedStateBackend flinkStateBackend, @@ -1179,6 +1326,7 @@ private static class FlinkMapState implements MapState mapKeyCoder, Coder mapValueCoder, + FlinkStateNamespaceKeySerializer namespaceSerializer, boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; @@ -1188,6 +1336,7 @@ private static class FlinkMapState implements MapState(mapKeyCoder, fasterCopy), new CoderTypeSerializer<>(mapValueCoder, fasterCopy)); + this.namespaceSerializer = namespaceSerializer; } @Override @@ -1205,7 +1354,7 @@ public ReadableState get(final KeyT input) { ValueT value = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .get(key); return (value != null) ? value : defaultValue; } catch (Exception e) { @@ -1225,7 +1374,7 @@ public void put(KeyT key, ValueT value) { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .put(key, value); } catch (Exception e) { throw new RuntimeException("Error put kv to state.", e); @@ -1239,13 +1388,13 @@ public ReadableState computeIfAbsent( ValueT current = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .get(key); if (current == null) { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .put(key, mappingFunction.apply(key)); } return ReadableStates.immediate(current); @@ -1259,7 +1408,7 @@ public void remove(KeyT key) { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .remove(key); } catch (Exception e) { throw new RuntimeException("Error remove map state key.", e); @@ -1275,7 +1424,7 @@ public Iterable read() { Iterable result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .keys(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1299,7 +1448,7 @@ public Iterable read() { Iterable result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .values(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1323,7 +1472,7 @@ public Iterable> read() { Iterable> result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .entries(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1362,7 +1511,7 @@ public void clear() { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1397,12 +1546,14 @@ private static class FlinkSetState implements SetState { private final String stateId; private final MapStateDescriptor flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkSetState( KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, Coder coder, + FlinkStateNamespaceKeySerializer namespaceSerializer, boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; @@ -1420,7 +1571,7 @@ public ReadableState contains(final T t) { Boolean result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .get(t); return ReadableStates.immediate(result != null && result); } catch (Exception e) { @@ -1433,7 +1584,7 @@ public ReadableState addIfAbsent(final T t) { try { org.apache.flink.api.common.state.MapState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); boolean alreadyContained = state.contains(t); if (!alreadyContained) { state.put(t, true); @@ -1449,7 +1600,7 @@ public void remove(T t) { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .remove(t); } catch (Exception e) { throw new RuntimeException("Error remove value to state.", e); @@ -1466,7 +1617,7 @@ public void add(T value) { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .put(value, true); } catch (Exception e) { throw new RuntimeException("Error add value to state.", e); @@ -1482,7 +1633,7 @@ public Boolean read() { Iterable result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .keys(); return result == null || Iterables.isEmpty(result); } catch (Exception e) { @@ -1503,7 +1654,7 @@ public Iterable read() { Iterable result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .keys(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1516,7 +1667,7 @@ public void clear() { try { flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1573,18 +1724,20 @@ public static class EarlyBinder implements StateBinder { private final KeyedStateBackend keyedStateBackend; private final Boolean fasterCopy; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; public EarlyBinder( - KeyedStateBackend keyedStateBackend, SerializablePipelineOptions pipelineOptions) { + KeyedStateBackend keyedStateBackend, SerializablePipelineOptions pipelineOptions, Coder windowCoder) { this.keyedStateBackend = keyedStateBackend; this.fasterCopy = pipelineOptions.get().as(FlinkPipelineOptions.class).getFasterCopy(); + this.namespaceSerializer = new FlinkStateNamespaceKeySerializer(windowCoder); } @Override public ValueState bindValue(String id, StateSpec> spec, Coder coder) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + namespaceSerializer, new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(coder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); @@ -1597,7 +1750,7 @@ public ValueState bindValue(String id, StateSpec> spec, Cod public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + namespaceSerializer, new ListStateDescriptor<>(id, new CoderTypeSerializer<>(elemCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); @@ -1610,7 +1763,7 @@ public BagState bindBag(String id, StateSpec> spec, Coder public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + namespaceSerializer, new MapStateDescriptor<>( id, new CoderTypeSerializer<>(elemCoder, fasterCopy), @@ -1629,7 +1782,7 @@ public org.apache.beam.sdk.state.MapState bindMap( Coder mapValueCoder) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + namespaceSerializer, new MapStateDescriptor<>( id, new CoderTypeSerializer<>(mapKeyCoder, fasterCopy), @@ -1645,7 +1798,7 @@ public OrderedListState bindOrderedList( String id, StateSpec> spec, Coder elemCoder) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + namespaceSerializer, new ListStateDescriptor<>( id, new CoderTypeSerializer<>(TimestampedValueCoder.of(elemCoder), fasterCopy))); @@ -1674,7 +1827,7 @@ public CombiningState bindCom Combine.CombineFn combineFn) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + namespaceSerializer, new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); @@ -1691,7 +1844,7 @@ CombiningState bindCombiningWithContext( CombineWithContext.CombineFnWithContext combineFn) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + namespaceSerializer, new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index a2d6f5027abb..7ea726699977 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -69,6 +69,7 @@ protected StateInternals createStateInternals() { return new FlinkStateInternals<>( keyedStateBackend, StringUtf8Coder.of(), + IntervalWindow.getCoder(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); } catch (Exception e) { throw new RuntimeException(e); @@ -82,6 +83,7 @@ public void testWatermarkHoldsPersistence() throws Exception { new FlinkStateInternals<>( keyedStateBackend, StringUtf8Coder.of(), + IntervalWindow.getCoder(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); StateTag stateTag = @@ -137,6 +139,7 @@ public void testWatermarkHoldsPersistence() throws Exception { new FlinkStateInternals<>( keyedStateBackend, StringUtf8Coder.of(), + IntervalWindow.getCoder(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); globalWindow = stateInternals.state(StateNamespaces.global(), stateTag); fixedWindow = @@ -174,6 +177,7 @@ public void testGlobalWindowWatermarkHoldClear() throws Exception { new FlinkStateInternals<>( keyedStateBackend, StringUtf8Coder.of(), + IntervalWindow.getCoder(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); StateTag stateTag = StateTags.watermarkStateInternal("hold", TimestampCombiner.EARLIEST); From 5088e93589a0fbb73385cdc222bd03a9c5b6d930 Mon Sep 17 00:00:00 2001 From: jto Date: Tue, 6 Aug 2024 14:53:14 +0200 Subject: [PATCH 07/19] [Flink] Make ToKeyedWorkItem part of the DoFnOperator --- ...nkStreamingPortablePipelineTranslator.java | 20 +- .../FlinkStreamingTransformTranslators.java | 155 ++--- .../wrappers/streaming/DoFnOperator.java | 52 +- .../ExecutableStageDoFnOperator.java | 3 +- .../streaming/SplittableDoFnOperator.java | 5 +- .../streaming/WindowDoFnOperator.java | 22 +- .../flink/FlinkPipelineOptionsTest.java | 4 +- .../wrappers/streaming/DoFnOperatorTest.java | 54 +- .../streaming/WindowDoFnOperatorTest.java | 620 +++++++++--------- 9 files changed, 456 insertions(+), 479 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java index 836c825300db..e7244bf982d0 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java @@ -430,24 +430,16 @@ private SingleOutputStreamOperator>>> add WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = WindowedValue.getFullCoder(workItemCoder, windowingStrategy.getWindowFn().windowCoder()); - CoderTypeInformation>> workItemTypeInfo = - new CoderTypeInformation<>(windowedWorkItemCoder, context.getPipelineOptions()); - - DataStream>> workItemStream = - inputDataStream - .flatMap( - new FlinkStreamingTransformTranslators.ToKeyedWorkItem<>( - context.getPipelineOptions())) - .returns(workItemTypeInfo) - .name("ToKeyedWorkItem"); - WorkItemKeySelector keySelector = new WorkItemKeySelector<>( inputElementCoder.getKeyCoder(), new SerializablePipelineOptions(context.getPipelineOptions())); - KeyedStream>, ByteBuffer> keyedWorkItemStream = - workItemStream.keyBy(keySelector); + KeyedStream>, ByteBuffer> keyedWorkItemStream = + inputDataStream.keyBy( + new KvToByteBufferKeySelector( + inputElementCoder.getKeyCoder(), + new SerializablePipelineOptions(context.getPipelineOptions()))); SystemReduceFn, Iterable, BoundedWindow> reduceFn = SystemReduceFn.buffering(inputElementCoder.getValueCoder()); @@ -872,7 +864,7 @@ private void translateExecutableStage( tagsToIds, new SerializablePipelineOptions(context.getPipelineOptions())); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new ExecutableStageDoFnOperator<>( transform.getUniqueName(), windowedInputCoder, diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index d243a88c5af4..4ef57424d526 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Map; import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.KeyedWorkItemCoder; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; @@ -496,7 +497,7 @@ public RawUnionValue map(T o) throws Exception { static class ParDoTranslationHelper { interface DoFnOperatorFactory { - DoFnOperator createDoFnOperator( + DoFnOperator createDoFnOperator( DoFn doFn, String stepName, List> sideInputs, @@ -604,7 +605,7 @@ static void translateParDo( context.getPipelineOptions()); if (sideInputs.isEmpty()) { - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = doFnOperatorFactory.createDoFnOperator( doFn, getCurrentTransformName(context), @@ -631,7 +632,7 @@ static void translateParDo( Tuple2>, DataStream> transformedSideInputs = transformSideInputs(sideInputs, context); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = doFnOperatorFactory.createDoFnOperator( doFn, getCurrentTransformName(context), @@ -943,36 +944,37 @@ public void translateNode( KvCoder inputKvCoder = (KvCoder) input.getCoder(); - SingletonKeyedWorkItemCoder workItemCoder = - SingletonKeyedWorkItemCoder.of( - inputKvCoder.getKeyCoder(), - ByteArrayCoder.of(), - input.getWindowingStrategy().getWindowFn().windowCoder()); - DataStream>> inputDataStream = context.getInputDataStream(input); - WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = + WindowedValue.FullWindowedValueCoder> windowedBinaryKVCoder = WindowedValue.getFullCoder( - workItemCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); + KvCoder.of(inputKvCoder.getKeyCoder(), ByteArrayCoder.of()), + input.getWindowingStrategy().getWindowFn().windowCoder()); - CoderTypeInformation>> workItemTypeInfo = - new CoderTypeInformation<>(windowedWorkItemCoder, context.getPipelineOptions()); + WindowedValue.FullWindowedValueCoder> windowedKeyedWorkItemCoder = + WindowedValue.getFullCoder( + KeyedWorkItemCoder.of( + inputKvCoder.getKeyCoder(), + ByteArrayCoder.of(), + input.getWindowingStrategy().getWindowFn().windowCoder()), + input.getWindowingStrategy().getWindowFn().windowCoder()); - DataStream>> workItemStream = + CoderTypeInformation>> binaryKVTypeInfo = + new CoderTypeInformation<>(windowedBinaryKVCoder, context.getPipelineOptions()); + + DataStream>> inputBinaryDataStream = inputDataStream - .flatMap( - new ToBinaryKeyedWorkItem<>( - context.getPipelineOptions(), inputKvCoder.getValueCoder())) - .returns(workItemTypeInfo) - .name("ToBinaryKeyedWorkItem"); + .flatMap(new ToBinaryKV<>(context.getPipelineOptions(), inputKvCoder.getValueCoder())) + .returns(binaryKVTypeInfo) + .name("ToBinaryKV"); - WorkItemKeySelector keySelector = - new WorkItemKeySelector<>( + KvToByteBufferKeySelector keySelector = + new KvToByteBufferKeySelector<>( inputKvCoder.getKeyCoder(), new SerializablePipelineOptions(context.getPipelineOptions())); - KeyedStream>, ByteBuffer> keyedWorkItemStream = - workItemStream.keyBy(keySelector); + KeyedStream>, ByteBuffer> keyedWorkItemStream = + inputBinaryDataStream.keyBy(keySelector); SystemReduceFn, Iterable, BoundedWindow> reduceFn = SystemReduceFn.buffering(ByteArrayCoder.of()); @@ -987,12 +989,17 @@ public void translateNode( TupleTag>> mainTag = new TupleTag<>("main output"); + WorkItemKeySelector workItemKeySelector = + new WorkItemKeySelector( + inputKvCoder.getKeyCoder(), + new SerializablePipelineOptions(context.getPipelineOptions())); + String fullName = getCurrentTransformName(context); WindowDoFnOperator> doFnOperator = new WindowDoFnOperator<>( reduceFn, fullName, - windowedWorkItemCoder, + windowedKeyedWorkItemCoder, mainTag, Collections.emptyList(), new DoFnOperator.MultiOutputOutputManagerFactory<>( @@ -1004,7 +1011,7 @@ public void translateNode( Collections.emptyList(), /* side inputs */ context.getPipelineOptions(), inputKvCoder.getKeyCoder(), - keySelector); + workItemKeySelector); final SingleOutputStreamOperator>>> outDataStream = keyedWorkItemStream @@ -1066,21 +1073,16 @@ public void translateNode( WindowedValue.getFullCoder( workItemCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); - CoderTypeInformation>> workItemTypeInfo = - new CoderTypeInformation<>(windowedWorkItemCoder, context.getPipelineOptions()); - - DataStream>> workItemStream = - inputDataStream - .flatMap(new ToKeyedWorkItem<>(context.getPipelineOptions())) - .returns(workItemTypeInfo) - .name("ToKeyedWorkItem"); - WorkItemKeySelector keySelector = new WorkItemKeySelector<>( inputKvCoder.getKeyCoder(), new SerializablePipelineOptions(context.getPipelineOptions())); - KeyedStream>, ByteBuffer> keyedWorkItemStream = - workItemStream.keyBy(keySelector); + + KeyedStream>, ByteBuffer> keyedStream = + inputDataStream.keyBy( + new KvToByteBufferKeySelector<>( + inputKvCoder.getKeyCoder(), + new SerializablePipelineOptions(context.getPipelineOptions()))); GlobalCombineFn combineFn = ((Combine.PerKey) transform).getFn(); SystemReduceFn reduceFn = @@ -1117,7 +1119,8 @@ public void translateNode( keySelector); SingleOutputStreamOperator>> outDataStream = - keyedWorkItemStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); + keyedStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); + context.setOutputDataStream(context.getOutput(transform), outDataStream); } else { Tuple2>, DataStream> transformSideInputs = @@ -1146,28 +1149,26 @@ public void translateNode( // allowed to have only one input keyed, normally. TwoInputTransformation< - WindowedValue>, - RawUnionValue, - WindowedValue>> + WindowedValue>, RawUnionValue, WindowedValue>> rawFlinkTransform = new TwoInputTransformation<>( - keyedWorkItemStream.getTransformation(), + keyedStream.getTransformation(), transformSideInputs.f1.broadcast().getTransformation(), transform.getName(), doFnOperator, outputTypeInfo, - keyedWorkItemStream.getParallelism()); + keyedStream.getParallelism()); - rawFlinkTransform.setStateKeyType(keyedWorkItemStream.getKeyType()); - rawFlinkTransform.setStateKeySelectors(keyedWorkItemStream.getKeySelector(), null); + rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); + rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); @SuppressWarnings({"unchecked", "rawtypes"}) SingleOutputStreamOperator>> outDataStream = new SingleOutputStreamOperator( - keyedWorkItemStream.getExecutionEnvironment(), + keyedStream.getExecutionEnvironment(), rawFlinkTransform) {}; // we have to cheat around the ctor being protected - keyedWorkItemStream.getExecutionEnvironment().addOperator(rawFlinkTransform); + keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); context.setOutputDataStream(context.getOutput(transform), outDataStream); } @@ -1332,51 +1333,13 @@ public void flatMap(T t, Collector collector) throws Exception { } } - static class ToKeyedWorkItem - extends RichFlatMapFunction< - WindowedValue>, WindowedValue>> { - - private final SerializablePipelineOptions options; - - ToKeyedWorkItem(PipelineOptions options) { - this.options = new SerializablePipelineOptions(options); - } - - @Override - public void open(Configuration parameters) { - // Initialize FileSystems for any coders which may want to use the FileSystem, - // see https://issues.apache.org/jira/browse/BEAM-8303 - FileSystems.setDefaultPipelineOptions(options.get()); - } - - @Override - public void flatMap( - WindowedValue> inWithMultipleWindows, - Collector>> out) { - - // we need to wrap each one work item per window for now - // since otherwise the PushbackSideInputRunner will not correctly - // determine whether side inputs are ready - // - // this is tracked as https://github.com/apache/beam/issues/18358 - for (WindowedValue> in : inWithMultipleWindows.explodeWindows()) { - SingletonKeyedWorkItem workItem = - new SingletonKeyedWorkItem<>( - in.getValue().getKey(), in.withValue(in.getValue().getValue())); - - out.collect(in.withValue(workItem)); - } - } - } - - static class ToBinaryKeyedWorkItem - extends RichFlatMapFunction< - WindowedValue>, WindowedValue>> { + static class ToBinaryKV + extends RichFlatMapFunction>, WindowedValue>> { private final SerializablePipelineOptions options; private final Coder valueCoder; - ToBinaryKeyedWorkItem(PipelineOptions options, Coder valueCoder) { + ToBinaryKV(PipelineOptions options, Coder valueCoder) { this.options = new SerializablePipelineOptions(options); this.valueCoder = valueCoder; } @@ -1390,22 +1353,10 @@ public void open(Configuration parameters) { @Override public void flatMap( - WindowedValue> inWithMultipleWindows, - Collector>> out) + WindowedValue> in, Collector>> out) throws CoderException { - - // we need to wrap each one work item per window for now - // since otherwise the PushbackSideInputRunner will not correctly - // determine whether side inputs are ready - // - // this is tracked as https://github.com/apache/beam/issues/18358 - for (WindowedValue> in : inWithMultipleWindows.explodeWindows()) { - final byte[] binaryValue = - CoderUtils.encodeToByteArray(valueCoder, in.getValue().getValue()); - final SingletonKeyedWorkItem workItem = - new SingletonKeyedWorkItem<>(in.getValue().getKey(), in.withValue(binaryValue)); - out.collect(in.withValue(workItem)); - } + final byte[] binaryValue = CoderUtils.encodeToByteArray(valueCoder, in.getValue().getValue()); + out.collect(in.withValue(KV.of(in.getValue().getKey(), binaryValue))); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index 055831a8103b..059826f8b897 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -27,6 +27,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; @@ -140,10 +141,10 @@ "keyfor", "nullness" }) // TODO(https://github.com/apache/beam/issues/20497) -public class DoFnOperator +public class DoFnOperator extends AbstractStreamOperatorCompat> - implements OneInputStreamOperator, WindowedValue>, - TwoInputStreamOperator, RawUnionValue, WindowedValue>, + implements OneInputStreamOperator, WindowedValue>, + TwoInputStreamOperator, RawUnionValue, WindowedValue>, Triggerable { private static final Logger LOG = LoggerFactory.getLogger(DoFnOperator.class); @@ -355,6 +356,11 @@ protected DoFn getDoFn() { return doFn; } + protected Iterable> preProcess(WindowedValue input) { + // Assume Input is PreInputT + return Collections.singletonList((WindowedValue) input); + } + // allow overriding this, for example SplittableDoFnOperator will not create a // stateful DoFn runner because ProcessFn, which is used for executing a Splittable DoFn // doesn't play by the normal DoFn rules and WindowDoFnOperator uses LateDataDroppingDoFnRunner @@ -686,30 +692,34 @@ protected final void setBundleFinishedCallback(Runnable callback) { } @Override - public final void processElement(StreamRecord> streamRecord) { - checkInvokeStartBundle(); - LOG.trace("Processing element {} in {}", streamRecord.getValue().getValue(), doFn.getClass()); - long oldHold = keyCoder != null ? keyedStateInternals.minWatermarkHoldMs() : -1L; - doFnRunner.processElement(streamRecord.getValue()); - checkInvokeFinishBundleByCount(); - emitWatermarkIfHoldChanged(oldHold); + public final void processElement(StreamRecord> streamRecord) { + for (WindowedValue e : preProcess(streamRecord.getValue())) { + checkInvokeStartBundle(); + LOG.trace("Processing element {} in {}", streamRecord.getValue().getValue(), doFn.getClass()); + long oldHold = keyCoder != null ? keyedStateInternals.minWatermarkHoldMs() : -1L; + doFnRunner.processElement(e); + checkInvokeFinishBundleByCount(); + emitWatermarkIfHoldChanged(oldHold); + } } @Override - public final void processElement1(StreamRecord> streamRecord) + public final void processElement1(StreamRecord> streamRecord) throws Exception { - checkInvokeStartBundle(); - Iterable> justPushedBack = - pushbackDoFnRunner.processElementInReadyWindows(streamRecord.getValue()); + for (WindowedValue e : preProcess(streamRecord.getValue())) { + checkInvokeStartBundle(); + Iterable> justPushedBack = + pushbackDoFnRunner.processElementInReadyWindows(e); - long min = pushedBackWatermark; - for (WindowedValue pushedBackValue : justPushedBack) { - min = Math.min(min, pushedBackValue.getTimestamp().getMillis()); - pushedBackElementsHandler.pushBack(pushedBackValue); - } - pushedBackWatermark = min; + long min = pushedBackWatermark; + for (WindowedValue pushedBackValue : justPushedBack) { + min = Math.min(min, pushedBackValue.getTimestamp().getMillis()); + pushedBackElementsHandler.pushBack(pushedBackValue); + } + pushedBackWatermark = min; - checkInvokeFinishBundleByCount(); + checkInvokeFinishBundleByCount(); + } } /** diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java index 7ec37cbe6dd3..446a4541dd1a 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java @@ -138,7 +138,8 @@ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) -public class ExecutableStageDoFnOperator extends DoFnOperator { +public class ExecutableStageDoFnOperator + extends DoFnOperator { private static final Logger LOG = LoggerFactory.getLogger(ExecutableStageDoFnOperator.class); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java index 59d09ae99966..c8b41587590f 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java @@ -64,7 +64,10 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class SplittableDoFnOperator - extends DoFnOperator>, OutputT> { + extends DoFnOperator< + KeyedWorkItem>, + KeyedWorkItem>, + OutputT> { private static final Logger LOG = LoggerFactory.getLogger(SplittableDoFnOperator.class); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java index d8f4885ea057..60b20f375f22 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java @@ -19,6 +19,7 @@ import static org.apache.beam.runners.core.TimerInternals.TimerData; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -50,7 +51,7 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class WindowDoFnOperator - extends DoFnOperator, KV> { + extends DoFnOperator, KeyedWorkItem, KV> { private final SystemReduceFn systemReduceFn; @@ -87,6 +88,25 @@ public WindowDoFnOperator( this.systemReduceFn = systemReduceFn; } + @Override + protected Iterable>> preProcess( + WindowedValue> inWithMultipleWindows) { + // we need to wrap each one work item per window for now + // since otherwise the PushbackSideInputRunner will not correctly + // determine whether side inputs are ready + // + // this is tracked as https://github.com/apache/beam/issues/18358 + ArrayList>> inputs = new ArrayList<>(); + for (WindowedValue> in : inWithMultipleWindows.explodeWindows()) { + SingletonKeyedWorkItem workItem = + new SingletonKeyedWorkItem<>( + in.getValue().getKey(), in.withValue(in.getValue().getValue())); + + inputs.add(in.withValue(workItem)); + } + return inputs; + } + @Override protected DoFnRunner, KV> createWrappingDoFnRunner( DoFnRunner, KV> wrappedRunner, StepContext stepContext) { diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java index c20bd077c3f2..9fa7aaca1b69 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java @@ -139,7 +139,7 @@ public void parDoBaseClassPipelineOptionsSerializationTest() throws Exception { TupleTag mainTag = new TupleTag<>("main-output"); Coder> coder = WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new TestDoFn(), "stepName", @@ -161,7 +161,7 @@ mainTag, coder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults()) final byte[] serialized = SerializationUtils.serialize(doFnOperator); @SuppressWarnings("unchecked") - DoFnOperator deserialized = SerializationUtils.deserialize(serialized); + DoFnOperator deserialized = SerializationUtils.deserialize(serialized); TypeInformation> typeInformation = TypeInformation.of(new TypeHint>() {}); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index 17cc16cc76e0..124fae05b03e 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -149,7 +149,7 @@ public void testSingleOutput() throws Exception { TupleTag outputTag = new TupleTag<>("main-output"); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -211,7 +211,7 @@ public void testMultiOutputOutput() throws Exception { .put(additionalOutput2, 2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new MultiOutputDoFn(additionalOutput1, additionalOutput2), "stepName", @@ -353,7 +353,7 @@ public void onProcessingTime(OnTimerContext context) { TupleTag outputTag = new TupleTag<>("main-output"); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( fn, "stepName", @@ -441,8 +441,8 @@ public void testWatermarkUpdateAfterWatermarkHoldRelease() throws Exception { KeySelector>, ByteBuffer> keySelector = e -> FlinkKeyUtils.encodeKey(e.getValue().getKey(), StringUtf8Coder.of()); - DoFnOperator, KV> doFnOperator = - new DoFnOperator, KV>( + DoFnOperator, KV, KV> doFnOperator = + new DoFnOperator, KV, KV>( new IdentityDoFn<>(), "stepName", coder, @@ -616,7 +616,7 @@ public void processElement(ProcessContext context) { TupleTag outputTag = new TupleTag<>("main-output"); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( fn, "stepName", @@ -866,7 +866,7 @@ public void onTimer(OnTimerContext context, @StateId(stateId) ValueState KeySelector>, ByteBuffer> keySelector = e -> FlinkKeyUtils.encodeKey(e.getValue().getKey(), StringUtf8Coder.of()); - DoFnOperator, KV> doFnOperator = + DoFnOperator, KV, KV> doFnOperator = new DoFnOperator<>( fn, "stepName", @@ -917,7 +917,7 @@ void testSideInputs(boolean keyed) throws Exception { keySelector = value -> FlinkKeyUtils.encodeKey(value.getValue(), keyCoder); } - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -1115,7 +1115,7 @@ public void nonKeyedParDoSideInputCheckpointing() throws Exception { .put(2, view2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -1158,7 +1158,7 @@ public void keyedParDoSideInputCheckpointing() throws Exception { .put(2, view2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -1261,7 +1261,7 @@ public void nonKeyedParDoPushbackDataCheckpointing() throws Exception { .put(2, view2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -1305,7 +1305,7 @@ public void keyedParDoPushbackDataCheckpointing() throws Exception { .put(2, view2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -1504,7 +1504,7 @@ OneInputStreamOperatorTestHarness, WindowedValue> creat TypeInformation keyCoderInfo, KeySelector, K> keySelector) throws Exception { - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( fn, "stepName", @@ -1554,7 +1554,7 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( doFn, "stepName", @@ -1603,7 +1603,7 @@ public void finishBundle(FinishBundleContext context) { testHarness.close(); - DoFnOperator newDoFnOperator = + DoFnOperator newDoFnOperator = new DoFnOperator<>( doFn, "stepName", @@ -1702,7 +1702,7 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(kvCoder.getValueCoder(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - DoFnOperator, String> doFnOperator = + DoFnOperator, KV, String> doFnOperator = new DoFnOperator<>( doFn, "stepName", @@ -1819,7 +1819,7 @@ public void testCheckpointBufferingWithMultipleBundles() throws Exception { WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - Supplier> doFnOperatorSupplier = + Supplier> doFnOperatorSupplier = () -> new DoFnOperator<>( new IdentityDoFn<>(), @@ -1838,7 +1838,7 @@ public void testCheckpointBufferingWithMultipleBundles() throws Exception { DoFnSchemaInformation.create(), Collections.emptyMap()); - DoFnOperator doFnOperator = doFnOperatorSupplier.get(); + DoFnOperator doFnOperator = doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -1943,7 +1943,7 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - Supplier> doFnOperatorSupplier = + Supplier> doFnOperatorSupplier = () -> new DoFnOperator<>( doFn, @@ -1962,7 +1962,7 @@ public void finishBundle(FinishBundleContext context) { DoFnSchemaInformation.create(), Collections.emptyMap()); - DoFnOperator doFnOperator = doFnOperatorSupplier.get(); + DoFnOperator doFnOperator = doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -2054,7 +2054,7 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - Supplier> doFnOperatorSupplier = + Supplier> doFnOperatorSupplier = () -> new DoFnOperator<>( doFn, @@ -2073,7 +2073,7 @@ public void finishBundle(FinishBundleContext context) { DoFnSchemaInformation.create(), Collections.emptyMap()); - DoFnOperator doFnOperator = doFnOperatorSupplier.get(); + DoFnOperator doFnOperator = doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -2151,7 +2151,7 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(kvCoder, GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - Supplier, KV>> doFnOperatorSupplier = + Supplier, KV, KV>> doFnOperatorSupplier = () -> new DoFnOperator<>( doFn, @@ -2170,7 +2170,7 @@ public void finishBundle(FinishBundleContext context) { DoFnSchemaInformation.create(), Collections.emptyMap()); - DoFnOperator, KV> doFnOperator = doFnOperatorSupplier.get(); + DoFnOperator, KV, KV> doFnOperator = doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness< WindowedValue>, WindowedValue>> testHarness = @@ -2307,7 +2307,7 @@ public void testBundleProcessingExceptionIsFatalDuringCheckpointing() throws Exc WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn() { @FinishBundle @@ -2346,7 +2346,7 @@ public void finishBundle() { @Test public void testAccumulatorRegistrationOnOperatorClose() throws Exception { - DoFnOperator doFnOperator = getOperatorForCleanupInspection(); + DoFnOperator doFnOperator = getOperatorForCleanupInspection(); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -2382,7 +2382,7 @@ public void testRemoveCachedClassReferences() throws Exception { assertThat(typeCache.size(), is(0)); } - private static DoFnOperator getOperatorForCleanupInspection() { + private static DoFnOperator getOperatorForCleanupInspection() { FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); options.setParallelism(4); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java index 8fab1bc6c167..fa00b942bad2 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java @@ -1,310 +1,310 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.translation.wrappers.streaming; - -import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; -import static org.apache.beam.runners.flink.translation.wrappers.streaming.StreamRecordStripper.stripStreamRecordFromWindowedValue; -import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; -import static org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.ON_TIME; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.core.Is.is; -import static org.joda.time.Duration.standardMinutes; -import static org.junit.Assert.assertEquals; - -import java.io.ByteArrayOutputStream; -import java.nio.ByteBuffer; -import org.apache.beam.runners.core.KeyedWorkItem; -import org.apache.beam.runners.core.SystemReduceFn; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.flink.FlinkPipelineOptions; -import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator.MultiOutputOutputManagerFactory; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderRegistry; -import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.coders.VarLongCoder; -import org.apache.beam.sdk.transforms.Sum; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.FixedWindows; -import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -import org.apache.beam.sdk.transforms.windowing.PaneInfo; -import org.apache.beam.sdk.util.AppliedCombineFn; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.api.java.typeutils.GenericTypeInfo; -import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; -import org.joda.time.Duration; -import org.joda.time.Instant; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link WindowDoFnOperator}. */ -@RunWith(JUnit4.class) -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) -}) -public class WindowDoFnOperatorTest { - - @Test - public void testRestore() throws Exception { - // test harness - KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue>> - testHarness = createTestHarness(getWindowDoFnOperator()); - testHarness.open(); - - // process elements - IntervalWindow window = new IntervalWindow(new Instant(0), Duration.millis(10_000)); - testHarness.processWatermark(0L); - testHarness.processElement( - Item.builder().key(1L).timestamp(1L).value(100L).window(window).build().toStreamRecord()); - testHarness.processElement( - Item.builder().key(1L).timestamp(2L).value(20L).window(window).build().toStreamRecord()); - testHarness.processElement( - Item.builder().key(2L).timestamp(3L).value(77L).window(window).build().toStreamRecord()); - - // create snapshot - OperatorSubtaskState snapshot = testHarness.snapshot(0, 0); - testHarness.close(); - - // restore from the snapshot - testHarness = createTestHarness(getWindowDoFnOperator()); - testHarness.initializeState(snapshot); - testHarness.open(); - - // close window - testHarness.processWatermark(10_000L); - - Iterable>> output = - stripStreamRecordFromWindowedValue(testHarness.getOutput()); - - assertEquals(2, Iterables.size(output)); - assertThat( - output, - containsInAnyOrder( - WindowedValue.of( - KV.of(1L, 120L), - new Instant(9_999), - window, - PaneInfo.createPane(true, true, ON_TIME)), - WindowedValue.of( - KV.of(2L, 77L), - new Instant(9_999), - window, - PaneInfo.createPane(true, true, ON_TIME)))); - // cleanup - testHarness.close(); - } - - @Test - public void testTimerCleanupOfPendingTimerList() throws Exception { - // test harness - WindowDoFnOperator windowDoFnOperator = getWindowDoFnOperator(); - KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue>> - testHarness = createTestHarness(windowDoFnOperator); - testHarness.open(); - - DoFnOperator, KV>.FlinkTimerInternals timerInternals = - windowDoFnOperator.timerInternals; - - // process elements - IntervalWindow window = new IntervalWindow(new Instant(0), Duration.millis(100)); - IntervalWindow window2 = new IntervalWindow(new Instant(100), Duration.millis(100)); - testHarness.processWatermark(0L); - - // Use two different keys to check for correct watermark hold calculation - testHarness.processElement( - Item.builder().key(1L).timestamp(1L).value(100L).window(window).build().toStreamRecord()); - testHarness.processElement( - Item.builder() - .key(2L) - .timestamp(150L) - .value(150L) - .window(window2) - .build() - .toStreamRecord()); - - testHarness.processWatermark(1); - - // Note that the following is 1 because the state is key-partitioned - assertThat(Iterables.size(timerInternals.pendingTimersById.keys()), is(1)); - - assertThat(testHarness.numKeyedStateEntries(), is(6)); - // close bundle - testHarness.setProcessingTime( - testHarness.getProcessingTime() - + 2 * FlinkPipelineOptions.defaults().getMaxBundleTimeMills()); - assertThat(windowDoFnOperator.getCurrentOutputWatermark(), is(1L)); - - // close window - testHarness.processWatermark(100L); - - // Note that the following is zero because we only the first key is active - assertThat(Iterables.size(timerInternals.pendingTimersById.keys()), is(0)); - - assertThat(testHarness.numKeyedStateEntries(), is(3)); - - // close bundle - testHarness.setProcessingTime( - testHarness.getProcessingTime() - + 2 * FlinkPipelineOptions.defaults().getMaxBundleTimeMills()); - assertThat(windowDoFnOperator.getCurrentOutputWatermark(), is(100L)); - - testHarness.processWatermark(200L); - - // All the state has been cleaned up - assertThat(testHarness.numKeyedStateEntries(), is(0)); - - assertThat( - stripStreamRecordFromWindowedValue(testHarness.getOutput()), - containsInAnyOrder( - WindowedValue.of( - KV.of(1L, 100L), new Instant(99), window, PaneInfo.createPane(true, true, ON_TIME)), - WindowedValue.of( - KV.of(2L, 150L), - new Instant(199), - window2, - PaneInfo.createPane(true, true, ON_TIME)))); - - // cleanup - testHarness.close(); - } - - private WindowDoFnOperator getWindowDoFnOperator() { - WindowingStrategy windowingStrategy = - WindowingStrategy.of(FixedWindows.of(standardMinutes(1))); - - TupleTag> outputTag = new TupleTag<>("main-output"); - - SystemReduceFn reduceFn = - SystemReduceFn.combining( - VarLongCoder.of(), - AppliedCombineFn.withInputCoder( - Sum.ofLongs(), - CoderRegistry.createDefault(), - KvCoder.of(VarLongCoder.of(), VarLongCoder.of()))); - - Coder windowCoder = windowingStrategy.getWindowFn().windowCoder(); - SingletonKeyedWorkItemCoder workItemCoder = - SingletonKeyedWorkItemCoder.of(VarLongCoder.of(), VarLongCoder.of(), windowCoder); - FullWindowedValueCoder> inputCoder = - WindowedValue.getFullCoder(workItemCoder, windowCoder); - FullWindowedValueCoder> outputCoder = - WindowedValue.getFullCoder(KvCoder.of(VarLongCoder.of(), VarLongCoder.of()), windowCoder); - - return new WindowDoFnOperator( - reduceFn, - "stepName", - (Coder) inputCoder, - outputTag, - emptyList(), - new MultiOutputOutputManagerFactory<>( - outputTag, - outputCoder, - new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), - windowingStrategy, - emptyMap(), - emptyList(), - FlinkPipelineOptions.defaults(), - VarLongCoder.of(), - new WorkItemKeySelector( - VarLongCoder.of(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults()))); - } - - private KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue>> - createTestHarness(WindowDoFnOperator windowDoFnOperator) throws Exception { - return new KeyedOneInputStreamOperatorTestHarness<>( - windowDoFnOperator, - (KeySelector>, ByteBuffer>) - o -> { - try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { - VarLongCoder.of().encode(o.getValue().key(), baos); - return ByteBuffer.wrap(baos.toByteArray()); - } - }, - new GenericTypeInfo<>(ByteBuffer.class)); - } - - private static class Item { - - static ItemBuilder builder() { - return new ItemBuilder(); - } - - private long key; - private long value; - private long timestamp; - private IntervalWindow window; - - StreamRecord>> toStreamRecord() { - WindowedValue item = WindowedValue.of(value, new Instant(timestamp), window, NO_FIRING); - WindowedValue> keyedItem = - WindowedValue.of( - new SingletonKeyedWorkItem<>(key, item), new Instant(timestamp), window, NO_FIRING); - return new StreamRecord<>(keyedItem); - } - - private static final class ItemBuilder { - - private long key; - private long value; - private long timestamp; - private IntervalWindow window; - - ItemBuilder key(long key) { - this.key = key; - return this; - } - - ItemBuilder value(long value) { - this.value = value; - return this; - } - - ItemBuilder timestamp(long timestamp) { - this.timestamp = timestamp; - return this; - } - - ItemBuilder window(IntervalWindow window) { - this.window = window; - return this; - } - - Item build() { - Item item = new Item(); - item.key = this.key; - item.value = this.value; - item.window = this.window; - item.timestamp = this.timestamp; - return item; - } - } - } -} +///* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +//package org.apache.beam.runners.flink.translation.wrappers.streaming; +// +//import static java.util.Collections.emptyList; +//import static java.util.Collections.emptyMap; +//import static org.apache.beam.runners.flink.translation.wrappers.streaming.StreamRecordStripper.stripStreamRecordFromWindowedValue; +//import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +//import static org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.ON_TIME; +//import static org.hamcrest.MatcherAssert.assertThat; +//import static org.hamcrest.Matchers.containsInAnyOrder; +//import static org.hamcrest.core.Is.is; +//import static org.joda.time.Duration.standardMinutes; +//import static org.junit.Assert.assertEquals; +// +//import java.io.ByteArrayOutputStream; +//import java.nio.ByteBuffer; +//import org.apache.beam.runners.core.KeyedWorkItem; +//import org.apache.beam.runners.core.SystemReduceFn; +//import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +//import org.apache.beam.runners.flink.FlinkPipelineOptions; +//import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator.MultiOutputOutputManagerFactory; +//import org.apache.beam.sdk.coders.Coder; +//import org.apache.beam.sdk.coders.CoderRegistry; +//import org.apache.beam.sdk.coders.KvCoder; +//import org.apache.beam.sdk.coders.VarLongCoder; +//import org.apache.beam.sdk.transforms.Sum; +//import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +//import org.apache.beam.sdk.transforms.windowing.FixedWindows; +//import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +//import org.apache.beam.sdk.transforms.windowing.PaneInfo; +//import org.apache.beam.sdk.util.AppliedCombineFn; +//import org.apache.beam.sdk.util.WindowedValue; +//import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; +//import org.apache.beam.sdk.values.KV; +//import org.apache.beam.sdk.values.TupleTag; +//import org.apache.beam.sdk.values.WindowingStrategy; +//import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +//import org.apache.flink.api.java.functions.KeySelector; +//import org.apache.flink.api.java.typeutils.GenericTypeInfo; +//import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +//import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +//import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +//import org.joda.time.Duration; +//import org.joda.time.Instant; +//import org.junit.Test; +//import org.junit.runner.RunWith; +//import org.junit.runners.JUnit4; +// +///** Tests for {@link WindowDoFnOperator}. */ +//@RunWith(JUnit4.class) +//@SuppressWarnings({ +// "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +//}) +//public class WindowDoFnOperatorTest { +// +// @Test +// public void testRestore() throws Exception { +// // test harness +// KeyedOneInputStreamOperatorTestHarness< +// ByteBuffer, WindowedValue>, WindowedValue>> +// testHarness = createTestHarness(getWindowDoFnOperator()); +// testHarness.open(); +// +// // process elements +// IntervalWindow window = new IntervalWindow(new Instant(0), Duration.millis(10_000)); +// testHarness.processWatermark(0L); +// testHarness.processElement( +// Item.builder().key(1L).timestamp(1L).value(100L).window(window).build().toStreamRecord()); +// testHarness.processElement( +// Item.builder().key(1L).timestamp(2L).value(20L).window(window).build().toStreamRecord()); +// testHarness.processElement( +// Item.builder().key(2L).timestamp(3L).value(77L).window(window).build().toStreamRecord()); +// +// // create snapshot +// OperatorSubtaskState snapshot = testHarness.snapshot(0, 0); +// testHarness.close(); +// +// // restore from the snapshot +// testHarness = createTestHarness(getWindowDoFnOperator()); +// testHarness.initializeState(snapshot); +// testHarness.open(); +// +// // close window +// testHarness.processWatermark(10_000L); +// +// Iterable>> output = +// stripStreamRecordFromWindowedValue(testHarness.getOutput()); +// +// assertEquals(2, Iterables.size(output)); +// assertThat( +// output, +// containsInAnyOrder( +// WindowedValue.of( +// KV.of(1L, 120L), +// new Instant(9_999), +// window, +// PaneInfo.createPane(true, true, ON_TIME)), +// WindowedValue.of( +// KV.of(2L, 77L), +// new Instant(9_999), +// window, +// PaneInfo.createPane(true, true, ON_TIME)))); +// // cleanup +// testHarness.close(); +// } +// +// @Test +// public void testTimerCleanupOfPendingTimerList() throws Exception { +// // test harness +// WindowDoFnOperator windowDoFnOperator = getWindowDoFnOperator(); +// KeyedOneInputStreamOperatorTestHarness< +// ByteBuffer, WindowedValue>, WindowedValue>> +// testHarness = createTestHarness(windowDoFnOperator); +// testHarness.open(); +// +// DoFnOperator, KeyedWorkItem, KV>.FlinkTimerInternals timerInternals = +// windowDoFnOperator.timerInternals; +// +// // process elements +// IntervalWindow window = new IntervalWindow(new Instant(0), Duration.millis(100)); +// IntervalWindow window2 = new IntervalWindow(new Instant(100), Duration.millis(100)); +// testHarness.processWatermark(0L); +// +// // Use two different keys to check for correct watermark hold calculation +// testHarness.processElement( +// Item.builder().key(1L).timestamp(1L).value(100L).window(window).build().toStreamRecord()); +// testHarness.processElement( +// Item.builder() +// .key(2L) +// .timestamp(150L) +// .value(150L) +// .window(window2) +// .build() +// .toStreamRecord()); +// +// testHarness.processWatermark(1); +// +// // Note that the following is 1 because the state is key-partitioned +// assertThat(Iterables.size(timerInternals.pendingTimersById.keys()), is(1)); +// +// assertThat(testHarness.numKeyedStateEntries(), is(6)); +// // close bundle +// testHarness.setProcessingTime( +// testHarness.getProcessingTime() +// + 2 * FlinkPipelineOptions.defaults().getMaxBundleTimeMills()); +// assertThat(windowDoFnOperator.getCurrentOutputWatermark(), is(1L)); +// +// // close window +// testHarness.processWatermark(100L); +// +// // Note that the following is zero because we only the first key is active +// assertThat(Iterables.size(timerInternals.pendingTimersById.keys()), is(0)); +// +// assertThat(testHarness.numKeyedStateEntries(), is(3)); +// +// // close bundle +// testHarness.setProcessingTime( +// testHarness.getProcessingTime() +// + 2 * FlinkPipelineOptions.defaults().getMaxBundleTimeMills()); +// assertThat(windowDoFnOperator.getCurrentOutputWatermark(), is(100L)); +// +// testHarness.processWatermark(200L); +// +// // All the state has been cleaned up +// assertThat(testHarness.numKeyedStateEntries(), is(0)); +// +// assertThat( +// stripStreamRecordFromWindowedValue(testHarness.getOutput()), +// containsInAnyOrder( +// WindowedValue.of( +// KV.of(1L, 100L), new Instant(99), window, PaneInfo.createPane(true, true, ON_TIME)), +// WindowedValue.of( +// KV.of(2L, 150L), +// new Instant(199), +// window2, +// PaneInfo.createPane(true, true, ON_TIME)))); +// +// // cleanup +// testHarness.close(); +// } +// +// private WindowDoFnOperator getWindowDoFnOperator() { +// WindowingStrategy windowingStrategy = +// WindowingStrategy.of(FixedWindows.of(standardMinutes(1))); +// +// TupleTag> outputTag = new TupleTag<>("main-output"); +// +// SystemReduceFn reduceFn = +// SystemReduceFn.combining( +// VarLongCoder.of(), +// AppliedCombineFn.withInputCoder( +// Sum.ofLongs(), +// CoderRegistry.createDefault(), +// KvCoder.of(VarLongCoder.of(), VarLongCoder.of()))); +// +// Coder windowCoder = windowingStrategy.getWindowFn().windowCoder(); +// SingletonKeyedWorkItemCoder workItemCoder = +// SingletonKeyedWorkItemCoder.of(VarLongCoder.of(), VarLongCoder.of(), windowCoder); +// FullWindowedValueCoder> inputCoder = +// WindowedValue.getFullCoder(workItemCoder, windowCoder); +// FullWindowedValueCoder> outputCoder = +// WindowedValue.getFullCoder(KvCoder.of(VarLongCoder.of(), VarLongCoder.of()), windowCoder); +// +// return new WindowDoFnOperator( +// reduceFn, +// "stepName", +// (Coder) inputCoder, +// outputTag, +// emptyList(), +// new MultiOutputOutputManagerFactory<>( +// outputTag, +// outputCoder, +// new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), +// windowingStrategy, +// emptyMap(), +// emptyList(), +// FlinkPipelineOptions.defaults(), +// VarLongCoder.of(), +// new WorkItemKeySelector( +// VarLongCoder.of(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults()))); +// } +// +// private KeyedOneInputStreamOperatorTestHarness< +// ByteBuffer, WindowedValue>, WindowedValue>> +// createTestHarness(WindowDoFnOperator windowDoFnOperator) throws Exception { +// return new KeyedOneInputStreamOperatorTestHarness<>( +// windowDoFnOperator, +// (KeySelector>, ByteBuffer>) +// o -> { +// try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { +// VarLongCoder.of().encode(o.getValue().key(), baos); +// return ByteBuffer.wrap(baos.toByteArray()); +// } +// }, +// new GenericTypeInfo<>(ByteBuffer.class)); +// } +// +// private static class Item { +// +// static ItemBuilder builder() { +// return new ItemBuilder(); +// } +// +// private long key; +// private long value; +// private long timestamp; +// private IntervalWindow window; +// +// StreamRecord>> toStreamRecord() { +// WindowedValue item = WindowedValue.of(value, new Instant(timestamp), window, NO_FIRING); +// WindowedValue> keyedItem = +// WindowedValue.of( +// new SingletonKeyedWorkItem<>(key, item), new Instant(timestamp), window, NO_FIRING); +// return new StreamRecord<>(keyedItem); +// } +// +// private static final class ItemBuilder { +// +// private long key; +// private long value; +// private long timestamp; +// private IntervalWindow window; +// +// ItemBuilder key(long key) { +// this.key = key; +// return this; +// } +// +// ItemBuilder value(long value) { +// this.value = value; +// return this; +// } +// +// ItemBuilder timestamp(long timestamp) { +// this.timestamp = timestamp; +// return this; +// } +// +// ItemBuilder window(IntervalWindow window) { +// this.window = window; +// return this; +// } +// +// Item build() { +// Item item = new Item(); +// item.key = this.key; +// item.value = this.value; +// item.window = this.window; +// item.timestamp = this.timestamp; +// return item; +// } +// } +// } +//} From 16e37a28df56a28d384c8abde5c2f055fd223fcd Mon Sep 17 00:00:00 2001 From: jto Date: Mon, 19 Aug 2024 21:33:37 +0200 Subject: [PATCH 08/19] [Flink] Remove ToBinaryKV --- .../FlinkStreamingTransformTranslators.java | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 4ef57424d526..50600d22f297 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -946,56 +946,56 @@ public void translateNode( DataStream>> inputDataStream = context.getInputDataStream(input); - WindowedValue.FullWindowedValueCoder> windowedBinaryKVCoder = - WindowedValue.getFullCoder( - KvCoder.of(inputKvCoder.getKeyCoder(), ByteArrayCoder.of()), - input.getWindowingStrategy().getWindowFn().windowCoder()); +// WindowedValue.FullWindowedValueCoder> windowedBinaryKVCoder = +// WindowedValue.getFullCoder( +// KvCoder.of(inputKvCoder.getKeyCoder(), ByteArrayCoder.of()), +// input.getWindowingStrategy().getWindowFn().windowCoder()); - WindowedValue.FullWindowedValueCoder> windowedKeyedWorkItemCoder = + WindowedValue.FullWindowedValueCoder> windowedKeyedWorkItemCoder = WindowedValue.getFullCoder( KeyedWorkItemCoder.of( inputKvCoder.getKeyCoder(), - ByteArrayCoder.of(), + inputKvCoder.getValueCoder(), input.getWindowingStrategy().getWindowFn().windowCoder()), input.getWindowingStrategy().getWindowFn().windowCoder()); - CoderTypeInformation>> binaryKVTypeInfo = - new CoderTypeInformation<>(windowedBinaryKVCoder, context.getPipelineOptions()); +// CoderTypeInformation>> binaryKVTypeInfo = +// new CoderTypeInformation<>(windowedBinaryKVCoder, context.getPipelineOptions()); - DataStream>> inputBinaryDataStream = - inputDataStream - .flatMap(new ToBinaryKV<>(context.getPipelineOptions(), inputKvCoder.getValueCoder())) - .returns(binaryKVTypeInfo) - .name("ToBinaryKV"); +// DataStream>> inputBinaryDataStream = +// inputDataStream +// .flatMap(new ToBinaryKV<>(context.getPipelineOptions(), inputKvCoder.getValueCoder())) +// .returns(binaryKVTypeInfo) +// .name("ToBinaryKV"); - KvToByteBufferKeySelector keySelector = + KvToByteBufferKeySelector keySelector = new KvToByteBufferKeySelector<>( inputKvCoder.getKeyCoder(), new SerializablePipelineOptions(context.getPipelineOptions())); - KeyedStream>, ByteBuffer> keyedWorkItemStream = - inputBinaryDataStream.keyBy(keySelector); + KeyedStream>, ByteBuffer> keyedWorkItemStream = + inputDataStream.keyBy(keySelector); - SystemReduceFn, Iterable, BoundedWindow> reduceFn = - SystemReduceFn.buffering(ByteArrayCoder.of()); + SystemReduceFn, Iterable, BoundedWindow> reduceFn = + SystemReduceFn.buffering(inputKvCoder.getValueCoder()); - Coder>>> outputCoder = + Coder>>> outputCoder = WindowedValue.getFullCoder( - KvCoder.of(inputKvCoder.getKeyCoder(), IterableCoder.of(ByteArrayCoder.of())), + KvCoder.of(inputKvCoder.getKeyCoder(), IterableCoder.of(inputKvCoder.getValueCoder())), windowingStrategy.getWindowFn().windowCoder()); - TypeInformation>>> outputTypeInfo = + TypeInformation>>> outputTypeInfo = new CoderTypeInformation<>(outputCoder, context.getPipelineOptions()); - TupleTag>> mainTag = new TupleTag<>("main output"); + TupleTag>> mainTag = new TupleTag<>("main output"); - WorkItemKeySelector workItemKeySelector = - new WorkItemKeySelector( + WorkItemKeySelector workItemKeySelector = + new WorkItemKeySelector<>( inputKvCoder.getKeyCoder(), new SerializablePipelineOptions(context.getPipelineOptions())); String fullName = getCurrentTransformName(context); - WindowDoFnOperator> doFnOperator = + WindowDoFnOperator> doFnOperator = new WindowDoFnOperator<>( reduceFn, fullName, @@ -1016,12 +1016,12 @@ public void translateNode( final SingleOutputStreamOperator>>> outDataStream = keyedWorkItemStream .transform(fullName, outputTypeInfo, doFnOperator) - .uid(fullName) - .flatMap( - new ToGroupByKeyResult<>( - context.getPipelineOptions(), inputKvCoder.getValueCoder())) - .returns(context.getTypeInfo(context.getOutput(transform))) - .name("ToGBKResult"); + .uid(fullName); +// .flatMap( +// new ToGroupByKeyResult<>( +// context.getPipelineOptions(), inputKvCoder.getValueCoder())) +// .returns(context.getTypeInfo(context.getOutput(transform))) +// .name("ToGBKResult"); context.setOutputDataStream(context.getOutput(transform), outDataStream); } From 7147d4bd023f0794bc47c971e8eda5ae54790034 Mon Sep 17 00:00:00 2001 From: jto Date: Thu, 8 Aug 2024 17:10:11 +0200 Subject: [PATCH 09/19] [Flink] Refactor CombinePerKeyTranslator --- .../FlinkStreamingTransformTranslators.java | 232 +++++++++++++----- .../wrappers/streaming/DoFnOperatorTest.java | 40 +-- .../streaming/WindowDoFnOperatorTest.java | 154 ++++++------ 3 files changed, 278 insertions(+), 148 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 50600d22f297..de9c2e114575 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -1048,75 +1048,213 @@ boolean canTranslate( || ((Combine.PerKey) transform).getSideInputs().isEmpty(); } - @Override - public void translateNode( + /* + private GlobalCombineFn toPartialFlinkCombineFn(GlobalCombineFn combineFn) { + + if(combineFn instanceof Combine.CombineFn) { + return new Combine.CombineFn() { + + Combine.CombineFn fn = + (Combine.CombineFn) combineFn; + + @Override + public Object createAccumulator() { + return fn.createAccumulator(); + } + + @Override + public Object addInput(Object mutableAccumulator, InputT input) { + return fn.addInput(mutableAccumulator, input); + } + + @Override + public Object mergeAccumulators(Iterable accumulators) { + return fn.mergeAccumulators(accumulators); + } + + @Override + public Object extractOutput(Object accumulator) { + return accumulator; + } + }; + } else if (combineFn instanceof CombineWithContext.CombineFnWithContext){ + return new CombineWithContext.CombineFnWithContext() { + CombineWithContext.CombineFnWithContext fn = + (CombineWithContext.CombineFnWithContext) combineFn; + @Override + public Object createAccumulator(CombineWithContext.Context c) { + return fn.createAccumulator(c); + } + + @Override + public Object addInput(Object accumulator, InputT input, CombineWithContext.Context c) { + return fn.addInput(accumulator, input, c); + } + + @Override + public Object mergeAccumulators(Iterable accumulators, CombineWithContext.Context c) { + return fn.mergeAccumulators(accumulators, c); + } + + @Override + public Object extractOutput(Object accumulator, CombineWithContext.Context c) { + return accumulator; + } + }; + } + + throw new IllegalArgumentException("Unsupported CombineFn implementation: " + combineFn.getClass()); + } + + private GlobalCombineFn toFinalFlinkCombineFn(GlobalCombineFn combineFn) { + + if(combineFn instanceof Combine.CombineFn) { + return new Combine.CombineFn() { + Combine.CombineFn fn = + (Combine.CombineFn) combineFn; + @Override + public Object createAccumulator() { + return fn.createAccumulator(); + } + + @Override + public Object addInput(Object mutableAccumulator, Object input) { + return fn.mergeAccumulators(ImmutableList.of(mutableAccumulator, input)); + } + + @Override + public Object mergeAccumulators(Iterable accumulators) { + return fn.mergeAccumulators(accumulators); + } + + @Override + public OutputT extractOutput(Object accumulator) { + return fn.extractOutput(accumulator); + } + }; + } else if (combineFn instanceof CombineWithContext.CombineFnWithContext){ + return new CombineWithContext.CombineFnWithContext() { + CombineWithContext.CombineFnWithContext fn = + (CombineWithContext.CombineFnWithContext) combineFn; + @Override + public Object createAccumulator(CombineWithContext.Context c) { + return fn.createAccumulator(c); + } + + @Override + public Object addInput(Object accumulator, Object input, CombineWithContext.Context c) { + return fn.mergeAccumulators(ImmutableList.of(accumulator, input), c); + } + + @Override + public Object mergeAccumulators(Iterable accumulators, CombineWithContext.Context c) { + return fn.mergeAccumulators(accumulators, c); + } + + @Override + public OutputT extractOutput(Object accumulator, CombineWithContext.Context c) { + return fn.extractOutput(accumulator, c); + } + }; + } + throw new IllegalArgumentException("Unsupported CombineFn implementation: " + combineFn.getClass()); + } + */ + + private WindowDoFnOperator getDoFnOperator( + FlinkStreamingTranslationContext context, PTransform>, PCollection>> transform, - FlinkStreamingTranslationContext context) { + GlobalCombineFn combineFn, + Map> sideInputTagMapping, + List> sideInputs) { + + // Naming String fullName = getCurrentTransformName(context); - PCollection> input = context.getInput(transform); + TupleTag> mainTag = new TupleTag<>("main output"); + // input infos + PCollection> input = context.getInput(transform); @SuppressWarnings("unchecked") WindowingStrategy windowingStrategy = (WindowingStrategy) input.getWindowingStrategy(); + SerializablePipelineOptions serializablePipelineOptions = + new SerializablePipelineOptions(context.getPipelineOptions()); + // Coders KvCoder inputKvCoder = (KvCoder) input.getCoder(); + Coder keyCoder = inputKvCoder.getKeyCoder(); SingletonKeyedWorkItemCoder workItemCoder = SingletonKeyedWorkItemCoder.of( - inputKvCoder.getKeyCoder(), + keyCoder, inputKvCoder.getValueCoder(), input.getWindowingStrategy().getWindowFn().windowCoder()); - DataStream>> inputDataStream = context.getInputDataStream(input); - WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = WindowedValue.getFullCoder( workItemCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); - WorkItemKeySelector keySelector = - new WorkItemKeySelector<>( - inputKvCoder.getKeyCoder(), - new SerializablePipelineOptions(context.getPipelineOptions())); - - KeyedStream>, ByteBuffer> keyedStream = - inputDataStream.keyBy( - new KvToByteBufferKeySelector<>( - inputKvCoder.getKeyCoder(), - new SerializablePipelineOptions(context.getPipelineOptions()))); + Coder>> outputCoder = + context.getWindowedInputCoder(context.getOutput(transform)); - GlobalCombineFn combineFn = ((Combine.PerKey) transform).getFn(); + // Combining fn SystemReduceFn reduceFn = SystemReduceFn.combining( - inputKvCoder.getKeyCoder(), + keyCoder, AppliedCombineFn.withInputCoder( combineFn, input.getPipeline().getCoderRegistry(), inputKvCoder)); - Coder>> outputCoder = - context.getWindowedInputCoder(context.getOutput(transform)); + // Key selector + WorkItemKeySelector workItemKeySelector = + new WorkItemKeySelector<>(keyCoder, serializablePipelineOptions); + + return new WindowDoFnOperator<>( + reduceFn, + fullName, + (Coder) windowedWorkItemCoder, + mainTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainTag, outputCoder, serializablePipelineOptions), + windowingStrategy, + sideInputTagMapping, + sideInputs, + context.getPipelineOptions(), + keyCoder, + workItemKeySelector); + } + + @Override + public void translateNode( + PTransform>, PCollection>> transform, + FlinkStreamingTranslationContext context) { + String fullName = getCurrentTransformName(context); + + PCollection> input = context.getInput(transform); + + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + Coder keyCoder = inputKvCoder.getKeyCoder(); + + DataStream>> inputDataStream = context.getInputDataStream(input); + + SerializablePipelineOptions serializablePipelineOptions = + new SerializablePipelineOptions(context.getPipelineOptions()); + + GlobalCombineFn combineFn = ((Combine.PerKey) transform).getFn(); + TypeInformation>> outputTypeInfo = context.getTypeInfo(context.getOutput(transform)); List> sideInputs = ((Combine.PerKey) transform).getSideInputs(); + KeyedStream>, ByteBuffer> keyedStream = + inputDataStream.keyBy( + new KvToByteBufferKeySelector<>(keyCoder, serializablePipelineOptions)); + if (sideInputs.isEmpty()) { - TupleTag> mainTag = new TupleTag<>("main output"); WindowDoFnOperator doFnOperator = - new WindowDoFnOperator<>( - reduceFn, - fullName, - (Coder) windowedWorkItemCoder, - mainTag, - Collections.emptyList(), - new DoFnOperator.MultiOutputOutputManagerFactory<>( - mainTag, - outputCoder, - new SerializablePipelineOptions(context.getPipelineOptions())), - windowingStrategy, - new HashMap<>(), /* side-input mapping */ - Collections.emptyList(), /* side inputs */ - context.getPipelineOptions(), - inputKvCoder.getKeyCoder(), - keySelector); + getDoFnOperator( + context, transform, combineFn, new HashMap<>(), Collections.emptyList()); SingleOutputStreamOperator>> outDataStream = keyedStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); @@ -1126,24 +1264,8 @@ public void translateNode( Tuple2>, DataStream> transformSideInputs = transformSideInputs(sideInputs, context); - TupleTag> mainTag = new TupleTag<>("main output"); WindowDoFnOperator doFnOperator = - new WindowDoFnOperator<>( - reduceFn, - fullName, - (Coder) windowedWorkItemCoder, - mainTag, - Collections.emptyList(), - new DoFnOperator.MultiOutputOutputManagerFactory<>( - mainTag, - outputCoder, - new SerializablePipelineOptions(context.getPipelineOptions())), - windowingStrategy, - transformSideInputs.f0, - sideInputs, - context.getPipelineOptions(), - inputKvCoder.getKeyCoder(), - keySelector); + getDoFnOperator(context, transform, combineFn, transformSideInputs.f0, sideInputs); // we have to manually contruct the two-input transform because we're not // allowed to have only one input keyed, normally. diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index 124fae05b03e..73873d94f1b7 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -2151,26 +2151,28 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(kvCoder, GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - Supplier, KV, KV>> doFnOperatorSupplier = - () -> - new DoFnOperator<>( - doFn, - "stepName", - windowedValueCoder, - Collections.emptyMap(), - outputTag, - Collections.emptyList(), - outputManagerFactory, - WindowingStrategy.globalDefault(), - new HashMap<>(), /* side-input mapping */ - Collections.emptyList(), /* side inputs */ - options, - keyCoder, - keySelector, - DoFnSchemaInformation.create(), - Collections.emptyMap()); + Supplier, KV, KV>> + doFnOperatorSupplier = + () -> + new DoFnOperator<>( + doFn, + "stepName", + windowedValueCoder, + Collections.emptyMap(), + outputTag, + Collections.emptyList(), + outputManagerFactory, + WindowingStrategy.globalDefault(), + new HashMap<>(), /* side-input mapping */ + Collections.emptyList(), /* side inputs */ + options, + keyCoder, + keySelector, + DoFnSchemaInformation.create(), + Collections.emptyMap()); - DoFnOperator, KV, KV> doFnOperator = doFnOperatorSupplier.get(); + DoFnOperator, KV, KV> doFnOperator = + doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness< WindowedValue>, WindowedValue>> testHarness = diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java index fa00b942bad2..22713f6b33c6 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java @@ -1,73 +1,75 @@ -///* -// * Licensed to the Apache Software Foundation (ASF) under one -// * or more contributor license agreements. See the NOTICE file -// * distributed with this work for additional information -// * regarding copyright ownership. The ASF licenses this file -// * to you under the Apache License, Version 2.0 (the -// * "License"); you may not use this file except in compliance -// * with the License. You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -//package org.apache.beam.runners.flink.translation.wrappers.streaming; -// -//import static java.util.Collections.emptyList; -//import static java.util.Collections.emptyMap; -//import static org.apache.beam.runners.flink.translation.wrappers.streaming.StreamRecordStripper.stripStreamRecordFromWindowedValue; -//import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; -//import static org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.ON_TIME; -//import static org.hamcrest.MatcherAssert.assertThat; -//import static org.hamcrest.Matchers.containsInAnyOrder; -//import static org.hamcrest.core.Is.is; -//import static org.joda.time.Duration.standardMinutes; -//import static org.junit.Assert.assertEquals; -// -//import java.io.ByteArrayOutputStream; -//import java.nio.ByteBuffer; -//import org.apache.beam.runners.core.KeyedWorkItem; -//import org.apache.beam.runners.core.SystemReduceFn; -//import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -//import org.apache.beam.runners.flink.FlinkPipelineOptions; -//import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator.MultiOutputOutputManagerFactory; -//import org.apache.beam.sdk.coders.Coder; -//import org.apache.beam.sdk.coders.CoderRegistry; -//import org.apache.beam.sdk.coders.KvCoder; -//import org.apache.beam.sdk.coders.VarLongCoder; -//import org.apache.beam.sdk.transforms.Sum; -//import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -//import org.apache.beam.sdk.transforms.windowing.FixedWindows; -//import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -//import org.apache.beam.sdk.transforms.windowing.PaneInfo; -//import org.apache.beam.sdk.util.AppliedCombineFn; -//import org.apache.beam.sdk.util.WindowedValue; -//import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; -//import org.apache.beam.sdk.values.KV; -//import org.apache.beam.sdk.values.TupleTag; -//import org.apache.beam.sdk.values.WindowingStrategy; -//import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -//import org.apache.flink.api.java.functions.KeySelector; -//import org.apache.flink.api.java.typeutils.GenericTypeInfo; -//import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; -//import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -//import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; -//import org.joda.time.Duration; -//import org.joda.time.Instant; -//import org.junit.Test; -//import org.junit.runner.RunWith; -//import org.junit.runners.JUnit4; -// -///** Tests for {@link WindowDoFnOperator}. */ -//@RunWith(JUnit4.class) -//@SuppressWarnings({ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; +// +// import static java.util.Collections.emptyList; +// import static java.util.Collections.emptyMap; +// import static +// org.apache.beam.runners.flink.translation.wrappers.streaming.StreamRecordStripper.stripStreamRecordFromWindowedValue; +// import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +// import static org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.ON_TIME; +// import static org.hamcrest.MatcherAssert.assertThat; +// import static org.hamcrest.Matchers.containsInAnyOrder; +// import static org.hamcrest.core.Is.is; +// import static org.joda.time.Duration.standardMinutes; +// import static org.junit.Assert.assertEquals; +// +// import java.io.ByteArrayOutputStream; +// import java.nio.ByteBuffer; +// import org.apache.beam.runners.core.KeyedWorkItem; +// import org.apache.beam.runners.core.SystemReduceFn; +// import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +// import org.apache.beam.runners.flink.FlinkPipelineOptions; +// import +// org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator.MultiOutputOutputManagerFactory; +// import org.apache.beam.sdk.coders.Coder; +// import org.apache.beam.sdk.coders.CoderRegistry; +// import org.apache.beam.sdk.coders.KvCoder; +// import org.apache.beam.sdk.coders.VarLongCoder; +// import org.apache.beam.sdk.transforms.Sum; +// import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +// import org.apache.beam.sdk.transforms.windowing.FixedWindows; +// import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +// import org.apache.beam.sdk.transforms.windowing.PaneInfo; +// import org.apache.beam.sdk.util.AppliedCombineFn; +// import org.apache.beam.sdk.util.WindowedValue; +// import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; +// import org.apache.beam.sdk.values.KV; +// import org.apache.beam.sdk.values.TupleTag; +// import org.apache.beam.sdk.values.WindowingStrategy; +// import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +// import org.apache.flink.api.java.functions.KeySelector; +// import org.apache.flink.api.java.typeutils.GenericTypeInfo; +// import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +// import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +// import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +// import org.joda.time.Duration; +// import org.joda.time.Instant; +// import org.junit.Test; +// import org.junit.runner.RunWith; +// import org.junit.runners.JUnit4; +// +/// ** Tests for {@link WindowDoFnOperator}. */ +// @RunWith(JUnit4.class) +// @SuppressWarnings({ // "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) -//}) -//public class WindowDoFnOperatorTest { +// }) +// public class WindowDoFnOperatorTest { // // @Test // public void testRestore() throws Exception { @@ -129,7 +131,8 @@ // testHarness = createTestHarness(windowDoFnOperator); // testHarness.open(); // -// DoFnOperator, KeyedWorkItem, KV>.FlinkTimerInternals timerInternals = +// DoFnOperator, KeyedWorkItem, KV>.FlinkTimerInternals +// timerInternals = // windowDoFnOperator.timerInternals; // // // process elements @@ -184,7 +187,8 @@ // stripStreamRecordFromWindowedValue(testHarness.getOutput()), // containsInAnyOrder( // WindowedValue.of( -// KV.of(1L, 100L), new Instant(99), window, PaneInfo.createPane(true, true, ON_TIME)), +// KV.of(1L, 100L), new Instant(99), window, PaneInfo.createPane(true, true, +// ON_TIME)), // WindowedValue.of( // KV.of(2L, 150L), // new Instant(199), @@ -238,7 +242,8 @@ // // private KeyedOneInputStreamOperatorTestHarness< // ByteBuffer, WindowedValue>, WindowedValue>> -// createTestHarness(WindowDoFnOperator windowDoFnOperator) throws Exception { +// createTestHarness(WindowDoFnOperator windowDoFnOperator) throws Exception +// { // return new KeyedOneInputStreamOperatorTestHarness<>( // windowDoFnOperator, // (KeySelector>, ByteBuffer>) @@ -263,7 +268,8 @@ // private IntervalWindow window; // // StreamRecord>> toStreamRecord() { -// WindowedValue item = WindowedValue.of(value, new Instant(timestamp), window, NO_FIRING); +// WindowedValue item = WindowedValue.of(value, new Instant(timestamp), window, +// NO_FIRING); // WindowedValue> keyedItem = // WindowedValue.of( // new SingletonKeyedWorkItem<>(key, item), new Instant(timestamp), window, NO_FIRING); @@ -307,4 +313,4 @@ // } // } // } -//} +// } From 7146a3720bbc1214b14b600deae039507905c200 Mon Sep 17 00:00:00 2001 From: jto Date: Fri, 9 Aug 2024 10:11:44 +0200 Subject: [PATCH 10/19] [Flink] Combine before Reduce (no side-input only) [Flink] Implement partial reduce [Flink] dead code cleanup [Flink] spotless [Flink] persistent PartialReduceBundleOperator operator state --- .../flink/FlinkExecutionEnvironments.java | 9 +- .../FlinkStreamingTransformTranslators.java | 396 ++++++++---------- .../wrappers/streaming/DoFnOperator.java | 8 +- .../PartialReduceBundleOperator.java | 175 ++++++++ .../streaming/state/FlinkStateInternals.java | 17 +- 5 files changed, 370 insertions(+), 235 deletions(-) create mode 100644 runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java index 67a091b46ff4..8c5d2cecfc0d 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java @@ -237,12 +237,15 @@ public static StreamExecutionEnvironment createStreamExecutionEnvironment( flinkStreamEnv.setParallelism(parallelism); if (options.getMaxParallelism() > 0) { flinkStreamEnv.setMaxParallelism(options.getMaxParallelism()); - } else if(!options.isStreaming()) { + } else if (!options.isStreaming()) { // In Flink maxParallelism defines the number of keyGroups. - // (see https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L76) + // (see + // https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L76) // The default value (parallelism * 1.5) - // (see https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L137-L147) + // (see + // https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L137-L147) // create a lot of skew so we force maxParallelism = parallelism in Batch mode. + LOG.info("Setting maxParallelism to {}", parallelism); flinkStreamEnv.setMaxParallelism(parallelism); } // set parallelism in the options (required by some execution code) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index de9c2e114575..244b5f83b78a 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -39,6 +39,7 @@ import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToByteBufferKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.PartialReduceBundleOperator; import org.apache.beam.runners.flink.translation.wrappers.streaming.ProcessingTimeCallbackCompat; import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItem; import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder; @@ -53,8 +54,10 @@ import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.bounded.FlinkBoundedSource; import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.unbounded.FlinkUnboundedSource; import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VoidCoder; @@ -66,6 +69,7 @@ import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn; +import org.apache.beam.sdk.transforms.CombineWithContext; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.Impulse; @@ -97,6 +101,7 @@ import org.apache.beam.sdk.values.ValueWithRecordId; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.flink.api.common.eventtime.WatermarkStrategy; @@ -946,11 +951,6 @@ public void translateNode( DataStream>> inputDataStream = context.getInputDataStream(input); -// WindowedValue.FullWindowedValueCoder> windowedBinaryKVCoder = -// WindowedValue.getFullCoder( -// KvCoder.of(inputKvCoder.getKeyCoder(), ByteArrayCoder.of()), -// input.getWindowingStrategy().getWindowFn().windowCoder()); - WindowedValue.FullWindowedValueCoder> windowedKeyedWorkItemCoder = WindowedValue.getFullCoder( KeyedWorkItemCoder.of( @@ -959,29 +959,21 @@ public void translateNode( input.getWindowingStrategy().getWindowFn().windowCoder()), input.getWindowingStrategy().getWindowFn().windowCoder()); -// CoderTypeInformation>> binaryKVTypeInfo = -// new CoderTypeInformation<>(windowedBinaryKVCoder, context.getPipelineOptions()); - -// DataStream>> inputBinaryDataStream = -// inputDataStream -// .flatMap(new ToBinaryKV<>(context.getPipelineOptions(), inputKvCoder.getValueCoder())) -// .returns(binaryKVTypeInfo) -// .name("ToBinaryKV"); - KvToByteBufferKeySelector keySelector = new KvToByteBufferKeySelector<>( inputKvCoder.getKeyCoder(), new SerializablePipelineOptions(context.getPipelineOptions())); KeyedStream>, ByteBuffer> keyedWorkItemStream = - inputDataStream.keyBy(keySelector); + inputDataStream.keyBy(keySelector); SystemReduceFn, Iterable, BoundedWindow> reduceFn = SystemReduceFn.buffering(inputKvCoder.getValueCoder()); Coder>>> outputCoder = WindowedValue.getFullCoder( - KvCoder.of(inputKvCoder.getKeyCoder(), IterableCoder.of(inputKvCoder.getValueCoder())), + KvCoder.of( + inputKvCoder.getKeyCoder(), IterableCoder.of(inputKvCoder.getValueCoder())), windowingStrategy.getWindowFn().windowCoder()); TypeInformation>>> outputTypeInfo = @@ -1014,14 +1006,7 @@ public void translateNode( workItemKeySelector); final SingleOutputStreamOperator>>> outDataStream = - keyedWorkItemStream - .transform(fullName, outputTypeInfo, doFnOperator) - .uid(fullName); -// .flatMap( -// new ToGroupByKeyResult<>( -// context.getPipelineOptions(), inputKvCoder.getValueCoder())) -// .returns(context.getTypeInfo(context.getOutput(transform))) -// .name("ToGBKResult"); + keyedWorkItemStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); context.setOutputDataStream(context.getOutput(transform), outDataStream); } @@ -1048,129 +1033,94 @@ boolean canTranslate( || ((Combine.PerKey) transform).getSideInputs().isEmpty(); } - /* - private GlobalCombineFn toPartialFlinkCombineFn(GlobalCombineFn combineFn) { - - if(combineFn instanceof Combine.CombineFn) { - return new Combine.CombineFn() { - - Combine.CombineFn fn = - (Combine.CombineFn) combineFn; - - @Override - public Object createAccumulator() { - return fn.createAccumulator(); - } - - @Override - public Object addInput(Object mutableAccumulator, InputT input) { - return fn.addInput(mutableAccumulator, input); - } - - @Override - public Object mergeAccumulators(Iterable accumulators) { - return fn.mergeAccumulators(accumulators); - } - - @Override - public Object extractOutput(Object accumulator) { - return accumulator; - } - }; - } else if (combineFn instanceof CombineWithContext.CombineFnWithContext){ - return new CombineWithContext.CombineFnWithContext() { - CombineWithContext.CombineFnWithContext fn = - (CombineWithContext.CombineFnWithContext) combineFn; - @Override - public Object createAccumulator(CombineWithContext.Context c) { - return fn.createAccumulator(c); - } - - @Override - public Object addInput(Object accumulator, InputT input, CombineWithContext.Context c) { - return fn.addInput(accumulator, input, c); - } - - @Override - public Object mergeAccumulators(Iterable accumulators, CombineWithContext.Context c) { - return fn.mergeAccumulators(accumulators, c); - } - - @Override - public Object extractOutput(Object accumulator, CombineWithContext.Context c) { - return accumulator; - } - }; + private static GlobalCombineFn toFinalFlinkCombineFn( + GlobalCombineFn combineFn, Coder inputTCoder) { + + if (combineFn instanceof Combine.CombineFn) { + return new Combine.CombineFn() { + + @SuppressWarnings("unchecked") + final Combine.CombineFn fn = + (Combine.CombineFn) combineFn; + + @Override + public Object createAccumulator() { + return fn.createAccumulator(); } - throw new IllegalArgumentException("Unsupported CombineFn implementation: " + combineFn.getClass()); - } + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return fn.getAccumulatorCoder(registry, inputTCoder); + } - private GlobalCombineFn toFinalFlinkCombineFn(GlobalCombineFn combineFn) { - - if(combineFn instanceof Combine.CombineFn) { - return new Combine.CombineFn() { - Combine.CombineFn fn = - (Combine.CombineFn) combineFn; - @Override - public Object createAccumulator() { - return fn.createAccumulator(); - } - - @Override - public Object addInput(Object mutableAccumulator, Object input) { - return fn.mergeAccumulators(ImmutableList.of(mutableAccumulator, input)); - } - - @Override - public Object mergeAccumulators(Iterable accumulators) { - return fn.mergeAccumulators(accumulators); - } - - @Override - public OutputT extractOutput(Object accumulator) { - return fn.extractOutput(accumulator); - } - }; - } else if (combineFn instanceof CombineWithContext.CombineFnWithContext){ - return new CombineWithContext.CombineFnWithContext() { - CombineWithContext.CombineFnWithContext fn = - (CombineWithContext.CombineFnWithContext) combineFn; - @Override - public Object createAccumulator(CombineWithContext.Context c) { - return fn.createAccumulator(c); - } - - @Override - public Object addInput(Object accumulator, Object input, CombineWithContext.Context c) { - return fn.mergeAccumulators(ImmutableList.of(accumulator, input), c); - } - - @Override - public Object mergeAccumulators(Iterable accumulators, CombineWithContext.Context c) { - return fn.mergeAccumulators(accumulators, c); - } - - @Override - public OutputT extractOutput(Object accumulator, CombineWithContext.Context c) { - return fn.extractOutput(accumulator, c); - } - }; + @Override + public Object addInput(Object mutableAccumulator, Object input) { + return fn.mergeAccumulators(ImmutableList.of(mutableAccumulator, input)); } - throw new IllegalArgumentException("Unsupported CombineFn implementation: " + combineFn.getClass()); - } - */ - private WindowDoFnOperator getDoFnOperator( - FlinkStreamingTranslationContext context, - PTransform>, PCollection>> transform, - GlobalCombineFn combineFn, - Map> sideInputTagMapping, - List> sideInputs) { + @Override + public Object mergeAccumulators(Iterable accumulators) { + return fn.mergeAccumulators(accumulators); + } + + @Override + public OutputT extractOutput(Object accumulator) { + return fn.extractOutput(accumulator); + } + }; + } else if (combineFn instanceof CombineWithContext.CombineFnWithContext) { + return new CombineWithContext.CombineFnWithContext() { + + @SuppressWarnings("unchecked") + final CombineWithContext.CombineFnWithContext fn = + (CombineWithContext.CombineFnWithContext) combineFn; + + @Override + public Object createAccumulator(CombineWithContext.Context c) { + return fn.createAccumulator(c); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return fn.getAccumulatorCoder(registry, inputTCoder); + } + + @Override + public Object addInput(Object accumulator, Object input, CombineWithContext.Context c) { + return fn.mergeAccumulators(ImmutableList.of(accumulator, input), c); + } + + @Override + public Object mergeAccumulators( + Iterable accumulators, CombineWithContext.Context c) { + return fn.mergeAccumulators(accumulators, c); + } + + @Override + public OutputT extractOutput(Object accumulator, CombineWithContext.Context c) { + return fn.extractOutput(accumulator, c); + } + }; + } + throw new IllegalArgumentException( + "Unsupported CombineFn implementation: " + combineFn.getClass()); + } + + private static + WindowDoFnOperator getDoFnOperator( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + KvCoder inputKvCoder, + Coder>> outputCoder, + GlobalCombineFn combineFn, + Map> sideInputTagMapping, + List> sideInputs) { // Naming String fullName = getCurrentTransformName(context); - TupleTag> mainTag = new TupleTag<>("main output"); + TupleTag> mainTag = new TupleTag<>("main output"); // input infos PCollection> input = context.getInput(transform); @@ -1181,31 +1131,26 @@ private WindowDoFnOperator getDoFnOperator( new SerializablePipelineOptions(context.getPipelineOptions()); // Coders - KvCoder inputKvCoder = (KvCoder) input.getCoder(); Coder keyCoder = inputKvCoder.getKeyCoder(); - SingletonKeyedWorkItemCoder workItemCoder = + SingletonKeyedWorkItemCoder workItemCoder = SingletonKeyedWorkItemCoder.of( keyCoder, inputKvCoder.getValueCoder(), - input.getWindowingStrategy().getWindowFn().windowCoder()); - - WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = - WindowedValue.getFullCoder( - workItemCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); + windowingStrategy.getWindowFn().windowCoder()); - Coder>> outputCoder = - context.getWindowedInputCoder(context.getOutput(transform)); + WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = + WindowedValue.getFullCoder(workItemCoder, windowingStrategy.getWindowFn().windowCoder()); // Combining fn - SystemReduceFn reduceFn = + SystemReduceFn reduceFn = SystemReduceFn.combining( keyCoder, AppliedCombineFn.withInputCoder( combineFn, input.getPipeline().getCoderRegistry(), inputKvCoder)); // Key selector - WorkItemKeySelector workItemKeySelector = + WorkItemKeySelector workItemKeySelector = new WorkItemKeySelector<>(keyCoder, serializablePipelineOptions); return new WindowDoFnOperator<>( @@ -1234,17 +1179,21 @@ public void translateNode( KvCoder inputKvCoder = (KvCoder) input.getCoder(); Coder keyCoder = inputKvCoder.getKeyCoder(); + Coder>> outputCoder = + context.getWindowedInputCoder(context.getOutput(transform)); DataStream>> inputDataStream = context.getInputDataStream(input); SerializablePipelineOptions serializablePipelineOptions = new SerializablePipelineOptions(context.getPipelineOptions()); - GlobalCombineFn combineFn = ((Combine.PerKey) transform).getFn(); + @SuppressWarnings("unchecked") + GlobalCombineFn combineFn = ((Combine.PerKey) transform).getFn(); TypeInformation>> outputTypeInfo = context.getTypeInfo(context.getOutput(transform)); + @SuppressWarnings("unchecked") List> sideInputs = ((Combine.PerKey) transform).getSideInputs(); KeyedStream>, ByteBuffer> keyedStream = @@ -1252,12 +1201,79 @@ public void translateNode( new KvToByteBufferKeySelector<>(keyCoder, serializablePipelineOptions)); if (sideInputs.isEmpty()) { - WindowDoFnOperator doFnOperator = - getDoFnOperator( - context, transform, combineFn, new HashMap<>(), Collections.emptyList()); + SingleOutputStreamOperator>> outDataStream; + + if (!context.isStreaming()) { + Coder>> windowedAccumCoder; + KvCoder accumKvCoder; + try { + @SuppressWarnings("unchecked") + Coder accumulatorCoder = + (Coder) + combineFn.getAccumulatorCoder( + input.getPipeline().getCoderRegistry(), inputKvCoder.getValueCoder()); + + accumKvCoder = KvCoder.of(inputKvCoder.getKeyCoder(), accumulatorCoder); + + windowedAccumCoder = + WindowedValue.getFullCoder( + accumKvCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } - SingleOutputStreamOperator>> outDataStream = - keyedStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); + TupleTag> mainTag = new TupleTag<>("main output"); + + PartialReduceBundleOperator partialDoFnOperator = + new PartialReduceBundleOperator<>( + (GlobalCombineFn) combineFn, + getCurrentTransformName(context), + context.getWindowedInputCoder(input), + mainTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainTag, windowedAccumCoder, serializablePipelineOptions), + input.getWindowingStrategy(), + new HashMap<>(), + Collections.emptyList(), + context.getPipelineOptions()); + + // final aggregation from AccumT to OutputT + WindowDoFnOperator finalDoFnOperator = + getDoFnOperator( + context, + transform, + accumKvCoder, + outputCoder, + toFinalFlinkCombineFn(combineFn, inputKvCoder.getValueCoder()), + new HashMap<>(), + Collections.emptyList()); + + String partialName = "Combine: " + fullName; + CoderTypeInformation>> partialTypeInfo = + new CoderTypeInformation<>(windowedAccumCoder, context.getPipelineOptions()); + + outDataStream = + inputDataStream + .transform(partialName, partialTypeInfo, partialDoFnOperator) + .uid(partialName) + .keyBy(new KvToByteBufferKeySelector<>(keyCoder, serializablePipelineOptions)) + .transform(fullName, outputTypeInfo, finalDoFnOperator) + .uid(fullName); + } else { + WindowDoFnOperator doFnOperator = + getDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + combineFn, + new HashMap<>(), + Collections.emptyList()); + + outDataStream = + keyedStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); + } context.setOutputDataStream(context.getOutput(transform), outDataStream); } else { @@ -1265,7 +1281,14 @@ public void translateNode( transformSideInputs(sideInputs, context); WindowDoFnOperator doFnOperator = - getDoFnOperator(context, transform, combineFn, transformSideInputs.f0, sideInputs); + getDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + combineFn, + transformSideInputs.f0, + sideInputs); // we have to manually contruct the two-input transform because we're not // allowed to have only one input keyed, normally. @@ -1455,65 +1478,6 @@ public void flatMap(T t, Collector collector) throws Exception { } } - static class ToBinaryKV - extends RichFlatMapFunction>, WindowedValue>> { - - private final SerializablePipelineOptions options; - private final Coder valueCoder; - - ToBinaryKV(PipelineOptions options, Coder valueCoder) { - this.options = new SerializablePipelineOptions(options); - this.valueCoder = valueCoder; - } - - @Override - public void open(Configuration parameters) { - // Initialize FileSystems for any coders which may want to use the FileSystem, - // see https://issues.apache.org/jira/browse/BEAM-8303 - FileSystems.setDefaultPipelineOptions(options.get()); - } - - @Override - public void flatMap( - WindowedValue> in, Collector>> out) - throws CoderException { - final byte[] binaryValue = CoderUtils.encodeToByteArray(valueCoder, in.getValue().getValue()); - out.collect(in.withValue(KV.of(in.getValue().getKey(), binaryValue))); - } - } - - static class ToGroupByKeyResult - extends RichFlatMapFunction< - WindowedValue>>, WindowedValue>>> { - - private final SerializablePipelineOptions options; - private final Coder valueCoder; - - ToGroupByKeyResult(PipelineOptions options, Coder valueCoder) { - this.options = new SerializablePipelineOptions(options); - this.valueCoder = valueCoder; - } - - @Override - public void open(Configuration parameters) { - // Initialize FileSystems for any coders which may want to use the FileSystem, - // see https://issues.apache.org/jira/browse/BEAM-8303 - FileSystems.setDefaultPipelineOptions(options.get()); - } - - @Override - public void flatMap( - WindowedValue>> element, - Collector>>> collector) - throws CoderException { - final List result = new ArrayList<>(); - for (byte[] binaryValue : element.getValue().getValue()) { - result.add(CoderUtils.decodeFromByteArray(valueCoder, binaryValue)); - } - collector.collect(element.withValue(KV.of(element.getValue().getKey(), result))); - } - } - /** Registers classes specialized to the Flink runner. */ @AutoService(TransformPayloadTranslatorRegistrar.class) public static class FlinkTransformsRegistrar implements TransformPayloadTranslatorRegistrar { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index 059826f8b897..3f076efbf298 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -268,7 +268,7 @@ public class DoFnOperator /** Constructor for DoFnOperator. */ public DoFnOperator( - DoFn doFn, + @Nullable DoFn doFn, String stepName, Coder> inputWindowedCoder, Map, Coder> outputCoders, @@ -279,8 +279,8 @@ public DoFnOperator( Map> sideInputTagMapping, Collection> sideInputs, PipelineOptions options, - Coder keyCoder, - KeySelector, ?> keySelector, + @Nullable Coder keyCoder, + @Nullable KeySelector, ?> keySelector, DoFnSchemaInformation doFnSchemaInformation, Map> sideInputMapping) { this.doFn = doFn; @@ -1014,7 +1014,7 @@ public void prepareSnapshotPreBarrier(long checkpointId) { } @Override - public final void snapshotState(StateSnapshotContext context) throws Exception { + public void snapshotState(StateSnapshotContext context) throws Exception { if (checkpointStats != null) { checkpointStats.snapshotStart(context.getCheckpointId()); } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java new file mode 100644 index 000000000000..b81d19889622 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; + +import java.util.*; +import java.util.stream.Collectors; + +import org.apache.beam.runners.flink.translation.functions.AbstractFlinkCombineRunner; +import org.apache.beam.runners.flink.translation.functions.HashingFlinkCombineRunner; +import org.apache.beam.runners.flink.translation.functions.SortingFlinkCombineRunner; +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.util.Collector; +import org.checkerframework.checker.nullness.qual.Nullable; + +public class PartialReduceBundleOperator + extends DoFnOperator, KV, KV> { + + private final CombineFnBase.GlobalCombineFn combineFn; + + private Multimap>> state; + private transient @Nullable ListState>> checkpointedState; + + public PartialReduceBundleOperator( + CombineFnBase.GlobalCombineFn combineFn, + String stepName, + Coder>> windowedInputCoder, + TupleTag> mainOutputTag, + List> additionalOutputTags, + OutputManagerFactory> outputManagerFactory, + WindowingStrategy windowingStrategy, + Map> sideInputTagMapping, + Collection> sideInputs, + PipelineOptions options) { + super( + null, + stepName, + windowedInputCoder, + Collections.emptyMap(), + mainOutputTag, + additionalOutputTags, + outputManagerFactory, + windowingStrategy, + sideInputTagMapping, + sideInputs, + options, + null, + null, + DoFnSchemaInformation.create(), + Collections.emptyMap()); + + this.combineFn = combineFn; + this.state = ArrayListMultimap.create(); + this.checkpointedState = null; + } + + @Override + public void open() throws Exception { + clearState(); + setBundleFinishedCallback(this::finishBundle); + super.open(); + } + + private void finishBundle() { + AbstractFlinkCombineRunner reduceRunner; + try { + if (windowingStrategy.needsMerge() && windowingStrategy.getWindowFn() instanceof Sessions) { + reduceRunner = new SortingFlinkCombineRunner<>(); + } else { + reduceRunner = new HashingFlinkCombineRunner<>(); + } + + for (Map.Entry>>> e : state.asMap().entrySet()) { + //noinspection unchecked + reduceRunner.combine( + new AbstractFlinkCombineRunner.PartialFlinkCombiner<>(combineFn), + (WindowingStrategy) windowingStrategy, + sideInputReader, + serializedOptions.get(), + e.getValue(), + new Collector>>() { + @Override + public void collect(WindowedValue> record) { + outputManager.output(mainOutputTag, record); + } + + @Override + public void close() {} + }); + } + + } catch (Exception e) { + throw new RuntimeException(e); + } + clearState(); + } + + private void clearState() { + this.state = ArrayListMultimap.create(); + if (this.checkpointedState != null) { + this.checkpointedState.clear(); + } + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + ListStateDescriptor>> descriptor = + new ListStateDescriptor<>( + "buffered-elements", + new CoderTypeSerializer<>(windowedInputCoder, serializedOptions)); + + checkpointedState = context.getOperatorStateStore().getListState(descriptor); + + if(context.isRestored() && this.checkpointedState != null) { + for(WindowedValue> wkv : this.checkpointedState.get()) { + this.state.put(wkv.getValue().getKey(), wkv); + } + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + if (this.checkpointedState != null) { + this.checkpointedState.update(new ArrayList<>(this.state.values())); + } + } + + @Override + protected DoFn, KV> getDoFn() { + return new DoFn, KV>() { + @ProcessElement + public void processElement(ProcessContext c, BoundedWindow window) throws Exception { + WindowedValue> windowedValue = + WindowedValue.of(c.element(), c.timestamp(), window, c.pane()); + state.put(Objects.requireNonNull(c.element()).getKey(), windowedValue); + } + }; + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index 8102582c4817..388271cdd68a 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -878,8 +878,7 @@ private static class FlinkCombiningState this.namespaceSerializer = namespaceSerializer; flinkStateDescriptor = - new ValueStateDescriptor<>( - stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); + new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); } @Override @@ -1047,8 +1046,7 @@ private static class FlinkCombiningStateWithContext this.namespaceSerializer = namespaceSerializer; flinkStateDescriptor = - new ValueStateDescriptor<>( - stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); + new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); } @Override @@ -1560,9 +1558,7 @@ private static class FlinkSetState implements SetState { this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = new MapStateDescriptor<>( - stateId, - new CoderTypeSerializer<>(coder, fasterCopy), - BooleanSerializer.INSTANCE); + stateId, new CoderTypeSerializer<>(coder, fasterCopy), BooleanSerializer.INSTANCE); } @Override @@ -1765,9 +1761,7 @@ public SetState bindSet(String id, StateSpec> spec, Coder keyedStateBackend.getOrCreateKeyedState( namespaceSerializer, new MapStateDescriptor<>( - id, - new CoderTypeSerializer<>(elemCoder, fasterCopy), - BooleanSerializer.INSTANCE)); + id, new CoderTypeSerializer<>(elemCoder, fasterCopy), BooleanSerializer.INSTANCE)); } catch (Exception e) { throw new RuntimeException(e); } @@ -1800,8 +1794,7 @@ public OrderedListState bindOrderedList( keyedStateBackend.getOrCreateKeyedState( namespaceSerializer, new ListStateDescriptor<>( - id, - new CoderTypeSerializer<>(TimestampedValueCoder.of(elemCoder), fasterCopy))); + id, new CoderTypeSerializer<>(TimestampedValueCoder.of(elemCoder), fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } From ad872abc0818992ed4f22cdba236607ec9c81dd6 Mon Sep 17 00:00:00 2001 From: jto Date: Fri, 23 Aug 2024 16:24:28 +0200 Subject: [PATCH 11/19] [Flink] Combine before GBK --- .../types/CoderTypeSerializer.java | 7 +- ...FlinkStreamingAggregationsTranslators.java | 286 ++++++++++ .../FlinkStreamingTransformTranslators.java | 504 ++++++------------ 3 files changed, 441 insertions(+), 356 deletions(-) create mode 100644 runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java diff --git a/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java index 911dd3185adf..30dde7ace394 100644 --- a/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ b/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java @@ -50,7 +50,12 @@ public class CoderTypeSerializer extends TypeSerializer { private final boolean fasterCopy; public CoderTypeSerializer(Coder coder, SerializablePipelineOptions pipelineOptions) { - this(coder, Preconditions.checkNotNull(pipelineOptions).get().as(FlinkPipelineOptions.class).getFasterCopy()); + this( + coder, + Preconditions.checkNotNull(pipelineOptions) + .get() + .as(FlinkPipelineOptions.class) + .getFasterCopy()); } public CoderTypeSerializer(Coder coder, boolean fasterCopy) { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java new file mode 100644 index 000000000000..60e0a1a8a058 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java @@ -0,0 +1,286 @@ +package org.apache.beam.runners.flink; + +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.wrappers.streaming.*; +import org.apache.beam.sdk.coders.*; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.AppliedCombineFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.*; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; + +import java.util.*; + +public class FlinkStreamingAggregationsTranslators { + public static class ConcatenateAsIterable extends Combine.CombineFn, Iterable> { + @Override + public List createAccumulator() { + return new ArrayList<>(); + } + + @Override + public List addInput(List accumulator, T input) { + accumulator.add(input); + return accumulator; + } + + @Override + public List mergeAccumulators(Iterable> accumulators) { + List result = createAccumulator(); + for (List accumulator : accumulators) { + result.addAll(accumulator); + } + return result; + } + + @Override + public List extractOutput(List accumulator) { + return accumulator; + } + + @Override + public Coder> getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return ListCoder.of(inputCoder); + } + + @Override + public Coder> getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) { + return IterableCoder.of(inputCoder); + } + } + + private static CombineFnBase.GlobalCombineFn toFinalFlinkCombineFn( + CombineFnBase.GlobalCombineFn combineFn, Coder inputTCoder) { + + if (combineFn instanceof Combine.CombineFn) { + return new Combine.CombineFn() { + + @SuppressWarnings("unchecked") + final Combine.CombineFn fn = + (Combine.CombineFn) combineFn; + + @Override + public Object createAccumulator() { + return fn.createAccumulator(); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return fn.getAccumulatorCoder(registry, inputTCoder); + } + + @Override + public Object addInput(Object mutableAccumulator, Object input) { + return fn.mergeAccumulators(ImmutableList.of(mutableAccumulator, input)); + } + + @Override + public Object mergeAccumulators(Iterable accumulators) { + return fn.mergeAccumulators(accumulators); + } + + @Override + public OutputT extractOutput(Object accumulator) { + return fn.extractOutput(accumulator); + } + }; + } else if (combineFn instanceof CombineWithContext.CombineFnWithContext) { + return new CombineWithContext.CombineFnWithContext() { + + @SuppressWarnings("unchecked") + final CombineWithContext.CombineFnWithContext fn = + (CombineWithContext.CombineFnWithContext) combineFn; + + @Override + public Object createAccumulator(CombineWithContext.Context c) { + return fn.createAccumulator(c); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return fn.getAccumulatorCoder(registry, inputTCoder); + } + + @Override + public Object addInput(Object accumulator, Object input, CombineWithContext.Context c) { + return fn.mergeAccumulators(ImmutableList.of(accumulator, input), c); + } + + @Override + public Object mergeAccumulators( + Iterable accumulators, CombineWithContext.Context c) { + return fn.mergeAccumulators(accumulators, c); + } + + @Override + public OutputT extractOutput(Object accumulator, CombineWithContext.Context c) { + return fn.extractOutput(accumulator, c); + } + }; + } + throw new IllegalArgumentException( + "Unsupported CombineFn implementation: " + combineFn.getClass()); + } + + /** + * Create a DoFnOperator instance that group elements per window and apply a combine function on them. + */ + public static WindowDoFnOperator getWindowedAggregateDoFnOperator( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + KvCoder inputKvCoder, + Coder>> outputCoder, + SystemReduceFn reduceFn, + Map> sideInputTagMapping, + List> sideInputs) { + + // Naming + String fullName = FlinkStreamingTransformTranslators.getCurrentTransformName(context); + TupleTag> mainTag = new TupleTag<>("main output"); + + // input infos + PCollection> input = context.getInput(transform); + + @SuppressWarnings("unchecked") + WindowingStrategy windowingStrategy = + (WindowingStrategy) input.getWindowingStrategy(); + SerializablePipelineOptions serializablePipelineOptions = + new SerializablePipelineOptions(context.getPipelineOptions()); + + // Coders + Coder keyCoder = inputKvCoder.getKeyCoder(); + + SingletonKeyedWorkItemCoder workItemCoder = + SingletonKeyedWorkItemCoder.of( + keyCoder, + inputKvCoder.getValueCoder(), + windowingStrategy.getWindowFn().windowCoder()); + + WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = + WindowedValue.getFullCoder(workItemCoder, windowingStrategy.getWindowFn().windowCoder()); + + // Key selector + WorkItemKeySelector workItemKeySelector = + new WorkItemKeySelector<>(keyCoder, serializablePipelineOptions); + + return new WindowDoFnOperator<>( + reduceFn, + fullName, + (Coder) windowedWorkItemCoder, + mainTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainTag, outputCoder, serializablePipelineOptions), + windowingStrategy, + sideInputTagMapping, + sideInputs, + context.getPipelineOptions(), + keyCoder, + workItemKeySelector); + } + + public static WindowDoFnOperator getWindowedAggregateDoFnOperator( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + KvCoder inputKvCoder, + Coder>> outputCoder, + CombineFnBase.GlobalCombineFn combineFn, + Map> sideInputTagMapping, + List> sideInputs) { + + // Combining fn + SystemReduceFn reduceFn = + SystemReduceFn.combining( + inputKvCoder.getKeyCoder(), + AppliedCombineFn.withInputCoder( + combineFn, context.getInput(transform).getPipeline().getCoderRegistry(), inputKvCoder)); + + return getWindowedAggregateDoFnOperator(context, transform, inputKvCoder, outputCoder, reduceFn, sideInputTagMapping, sideInputs); + } + + public static SingleOutputStreamOperator>> batchCombinePerKeyNoSideInputs( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + CombineFnBase.GlobalCombineFn combineFn) { + + Coder>> windowedAccumCoder; + KvCoder accumKvCoder; + + PCollection> input = context.getInput(transform); + String fullName = FlinkStreamingTransformTranslators.getCurrentTransformName(context); + DataStream>> inputDataStream = context.getInputDataStream(input); + KvCoder inputKvCoder = + (KvCoder) input.getCoder(); + Coder>> outputCoder = + context.getWindowedInputCoder(context.getOutput(transform)); + SerializablePipelineOptions serializablePipelineOptions = + new SerializablePipelineOptions(context.getPipelineOptions()); + TypeInformation>> outputTypeInfo = + context.getTypeInfo(context.getOutput(transform)); + + try { + Coder accumulatorCoder = + combineFn.getAccumulatorCoder( + input.getPipeline().getCoderRegistry(), inputKvCoder.getValueCoder()); + + accumKvCoder = KvCoder.of(inputKvCoder.getKeyCoder(), accumulatorCoder); + + windowedAccumCoder = + WindowedValue.getFullCoder( + accumKvCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + + TupleTag> mainTag = new TupleTag<>("main output"); + + PartialReduceBundleOperator partialDoFnOperator = + new PartialReduceBundleOperator<>( + combineFn, + FlinkStreamingTransformTranslators.getCurrentTransformName(context), + context.getWindowedInputCoder(input), + mainTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainTag, windowedAccumCoder, serializablePipelineOptions), + input.getWindowingStrategy(), + new HashMap<>(), + Collections.emptyList(), + context.getPipelineOptions()); + + // final aggregation from AccumT to OutputT + WindowDoFnOperator finalDoFnOperator = + getWindowedAggregateDoFnOperator( + context, + transform, + accumKvCoder, + outputCoder, + toFinalFlinkCombineFn(combineFn, inputKvCoder.getValueCoder()), + new HashMap<>(), + Collections.emptyList()); + + String partialName = "Combine: " + fullName; + CoderTypeInformation>> partialTypeInfo = + new CoderTypeInformation<>(windowedAccumCoder, context.getPipelineOptions()); + + return + inputDataStream + .transform(partialName, partialTypeInfo, partialDoFnOperator) + .uid(partialName) + .keyBy(new KvToByteBufferKeySelector<>(inputKvCoder.getKeyCoder(), serializablePipelineOptions)) + .transform(fullName, outputTypeInfo, finalDoFnOperator) + .uid(fullName); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 244b5f83b78a..0a3d063f94c0 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -21,6 +21,7 @@ import static org.apache.beam.sdk.util.construction.SplittableParDo.SPLITTABLE_PROCESS_URN; import com.google.auto.service.AutoService; + import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -29,17 +30,14 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import org.apache.beam.runners.core.KeyedWorkItem; -import org.apache.beam.runners.core.KeyedWorkItemCoder; -import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; -import org.apache.beam.runners.core.SystemReduceFn; + +import org.apache.beam.runners.core.*; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.flink.translation.functions.FlinkAssignWindows; import org.apache.beam.runners.flink.translation.functions.ImpulseSourceFunction; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToByteBufferKeySelector; -import org.apache.beam.runners.flink.translation.wrappers.streaming.PartialReduceBundleOperator; import org.apache.beam.runners.flink.translation.wrappers.streaming.ProcessingTimeCallbackCompat; import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItem; import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder; @@ -53,14 +51,7 @@ import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSource; import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.bounded.FlinkBoundedSource; import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.unbounded.FlinkUnboundedSource; -import org.apache.beam.sdk.coders.ByteArrayCoder; -import org.apache.beam.sdk.coders.CannotProvideCoderException; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.CoderRegistry; -import org.apache.beam.sdk.coders.IterableCoder; -import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.coders.*; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.UnboundedSource; @@ -69,7 +60,6 @@ import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn; -import org.apache.beam.sdk.transforms.CombineWithContext; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.Impulse; @@ -82,7 +72,6 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.WindowFn; -import org.apache.beam.sdk.util.AppliedCombineFn; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.construction.PTransformTranslation; @@ -101,7 +90,6 @@ import org.apache.beam.sdk.values.ValueWithRecordId; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.flink.api.common.eventtime.WatermarkStrategy; @@ -136,8 +124,8 @@ * encountered Beam transformations into Flink one, based on the mapping available in this class. */ @SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) class FlinkStreamingTransformTranslators { @@ -145,7 +133,9 @@ class FlinkStreamingTransformTranslators { // Transform Translator Registry // -------------------------------------------------------------------------------------------- - /** A map from a Transform URN to the translator. */ + /** + * A map from a Transform URN to the translator. + */ @SuppressWarnings("rawtypes") private static final Map TRANSLATORS = new HashMap<>(); @@ -182,7 +172,7 @@ public static FlinkStreamingPipelineTranslator.StreamTransformTranslator getT } @SuppressWarnings("unchecked") - private static String getCurrentTransformName(FlinkStreamingTranslationContext context) { + public static String getCurrentTransformName(FlinkStreamingTranslationContext context) { return context.getCurrentTransform().getFullName(); } @@ -192,7 +182,7 @@ private static String getCurrentTransformName(FlinkStreamingTranslationContext c private static class UnboundedReadSourceTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>> { + PTransform>> { @Override public void translateNode( @@ -268,7 +258,7 @@ public void translateNode( static class ValueWithRecordIdKeySelector implements KeySelector>, ByteBuffer>, - ResultTypeQueryable { + ResultTypeQueryable { @Override public ByteBuffer getKey(WindowedValue> value) throws Exception { @@ -341,7 +331,7 @@ void translateNode(Impulse transform, FlinkStreamingTranslationContext context) private static class ReadSourceTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>> { + PTransform>> { private final BoundedReadSourceTranslator boundedTranslator = new BoundedReadSourceTranslator<>(); @@ -362,7 +352,7 @@ void translateNode( private static class BoundedReadSourceTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>> { + PTransform>> { @Override public void translateNode( @@ -414,7 +404,9 @@ public void translateNode( } } - /** Wraps each element in a {@link RawUnionValue} with the given tag id. */ + /** + * Wraps each element in a {@link RawUnionValue} with the given tag id. + */ public static class ToRawUnion extends RichMapFunction { private final int intTag; private final SerializablePipelineOptions options; @@ -438,8 +430,8 @@ public RawUnionValue map(T o) throws Exception { } private static Tuple2>, DataStream> - transformSideInputs( - Collection> sideInputs, FlinkStreamingTranslationContext context) { + transformSideInputs( + Collection> sideInputs, FlinkStreamingTranslationContext context) { // collect all side inputs Map, Integer> tagToIntMapping = new HashMap<>(); @@ -662,15 +654,15 @@ static void translateParDo( // allowed to have only one input keyed, normally. KeyedStream keyedStream = (KeyedStream) inputDataStream; TwoInputTransformation< - WindowedValue>, RawUnionValue, WindowedValue> + WindowedValue>, RawUnionValue, WindowedValue> rawFlinkTransform = - new TwoInputTransformation( - keyedStream.getTransformation(), - transformedSideInputs.f1.broadcast().getTransformation(), - transformName, - doFnOperator, - outputTypeInformation, - keyedStream.getParallelism()); + new TwoInputTransformation( + keyedStream.getTransformation(), + transformedSideInputs.f1.broadcast().getTransformation(), + transformName, + doFnOperator, + outputTypeInformation, + keyedStream.getParallelism()); rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); @@ -678,7 +670,8 @@ static void translateParDo( outputStream = new SingleOutputStreamOperator( keyedStream.getExecutionEnvironment(), - rawFlinkTransform) {}; // we have to cheat around the ctor being protected + rawFlinkTransform) { + }; // we have to cheat around the ctor being protected keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); @@ -704,7 +697,7 @@ static void translateParDo( private static class ParDoStreamingTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform, PCollectionTuple>> { + PTransform, PCollectionTuple>> { @Override public void translateNode( @@ -759,22 +752,22 @@ public void translateNode( sideInputMapping, context, (doFn1, - stepName, - sideInputs1, - mainOutputTag1, - additionalOutputTags1, - context1, - windowingStrategy, - tagsToOutputTags, - tagsToCoders, - tagsToIds, - windowedInputCoder, - outputCoders1, - keyCoder, - keySelector, - transformedSideInputs, - doFnSchemaInformation1, - sideInputMapping1) -> + stepName, + sideInputs1, + mainOutputTag1, + additionalOutputTags1, + context1, + windowingStrategy, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + windowedInputCoder, + outputCoders1, + keyCoder, + keySelector, + transformedSideInputs, + doFnSchemaInformation1, + sideInputMapping1) -> new DoFnOperator<>( doFn1, stepName, @@ -800,15 +793,15 @@ public void translateNode( } private static class SplittableProcessElementsStreamingTranslator< - InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> + InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - SplittableParDoViaKeyedWorkItems.ProcessElements< - InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>> { + SplittableParDoViaKeyedWorkItems.ProcessElements< + InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>> { @Override public void translateNode( SplittableParDoViaKeyedWorkItems.ProcessElements< - InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> + InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> transform, FlinkStreamingTranslationContext context) { @@ -824,22 +817,22 @@ public void translateNode( Collections.emptyMap(), context, (doFn, - stepName, - sideInputs, - mainOutputTag, - additionalOutputTags, - context1, - windowingStrategy, - tagsToOutputTags, - tagsToCoders, - tagsToIds, - windowedInputCoder, - outputCoders1, - keyCoder, - keySelector, - transformedSideInputs, - doFnSchemaInformation, - sideInputMapping) -> + stepName, + sideInputs, + mainOutputTag, + additionalOutputTags, + context1, + windowingStrategy, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + windowedInputCoder, + outputCoders1, + keyCoder, + keySelector, + transformedSideInputs, + doFnSchemaInformation, + sideInputMapping) -> new SplittableDoFnOperator<>( doFn, stepName, @@ -864,7 +857,7 @@ public void translateNode( private static class CreateViewStreamingTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - CreateStreamingFlinkView.CreateFlinkPCollectionView> { + CreateStreamingFlinkView.CreateFlinkPCollectionView> { @Override public void translateNode( @@ -882,7 +875,7 @@ public void translateNode( private static class WindowAssignTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform, PCollection>> { + PTransform, PCollection>> { @Override public void translateNode( @@ -918,7 +911,7 @@ public void translateNode( private static class ReshuffleTranslatorStreaming extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>, PCollection>>> { + PTransform>, PCollection>>> { @Override public void translateNode( @@ -934,7 +927,7 @@ public void translateNode( private static class GroupByKeyTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>, PCollection>>>> { + PTransform>, PCollection>>>> { @Override public void translateNode( @@ -942,79 +935,62 @@ public void translateNode( FlinkStreamingTranslationContext context) { PCollection> input = context.getInput(transform); - @SuppressWarnings("unchecked") WindowingStrategy windowingStrategy = (WindowingStrategy) input.getWindowingStrategy(); - KvCoder inputKvCoder = (KvCoder) input.getCoder(); - DataStream>> inputDataStream = context.getInputDataStream(input); - - WindowedValue.FullWindowedValueCoder> windowedKeyedWorkItemCoder = - WindowedValue.getFullCoder( - KeyedWorkItemCoder.of( - inputKvCoder.getKeyCoder(), - inputKvCoder.getValueCoder(), - input.getWindowingStrategy().getWindowFn().windowCoder()), - input.getWindowingStrategy().getWindowFn().windowCoder()); - - KvToByteBufferKeySelector keySelector = - new KvToByteBufferKeySelector<>( - inputKvCoder.getKeyCoder(), - new SerializablePipelineOptions(context.getPipelineOptions())); - - KeyedStream>, ByteBuffer> keyedWorkItemStream = - inputDataStream.keyBy(keySelector); - - SystemReduceFn, Iterable, BoundedWindow> reduceFn = - SystemReduceFn.buffering(inputKvCoder.getValueCoder()); - - Coder>>> outputCoder = - WindowedValue.getFullCoder( - KvCoder.of( - inputKvCoder.getKeyCoder(), IterableCoder.of(inputKvCoder.getValueCoder())), - windowingStrategy.getWindowFn().windowCoder()); - - TypeInformation>>> outputTypeInfo = - new CoderTypeInformation<>(outputCoder, context.getPipelineOptions()); - - TupleTag>> mainTag = new TupleTag<>("main output"); - - WorkItemKeySelector workItemKeySelector = - new WorkItemKeySelector<>( - inputKvCoder.getKeyCoder(), - new SerializablePipelineOptions(context.getPipelineOptions())); - String fullName = getCurrentTransformName(context); - WindowDoFnOperator> doFnOperator = - new WindowDoFnOperator<>( - reduceFn, - fullName, - windowedKeyedWorkItemCoder, - mainTag, - Collections.emptyList(), - new DoFnOperator.MultiOutputOutputManagerFactory<>( - mainTag, - outputCoder, - new SerializablePipelineOptions(context.getPipelineOptions())), - windowingStrategy, - new HashMap<>(), /* side-input mapping */ - Collections.emptyList(), /* side inputs */ - context.getPipelineOptions(), - inputKvCoder.getKeyCoder(), - workItemKeySelector); - final SingleOutputStreamOperator>>> outDataStream = - keyedWorkItemStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); + SingleOutputStreamOperator>>> outDataStream; + // Pre-aggregate before shuffle similar to group combine + if (!context.isStreaming()) { + outDataStream = + FlinkStreamingAggregationsTranslators.batchCombinePerKeyNoSideInputs( + context, + transform, + new FlinkStreamingAggregationsTranslators.ConcatenateAsIterable<>()); + } else { + // No pre-aggregation in Streaming mode. + KvToByteBufferKeySelector keySelector = + new KvToByteBufferKeySelector<>( + inputKvCoder.getKeyCoder(), + new SerializablePipelineOptions(context.getPipelineOptions())); + + Coder>>> outputCoder = + WindowedValue.getFullCoder( + KvCoder.of( + inputKvCoder.getKeyCoder(), IterableCoder.of(inputKvCoder.getValueCoder())), + windowingStrategy.getWindowFn().windowCoder()); + + TypeInformation>>> outputTypeInfo = + new CoderTypeInformation<>(outputCoder, context.getPipelineOptions()); + + WindowDoFnOperator> doFnOperator = + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + SystemReduceFn.buffering(inputKvCoder.getValueCoder()), + new HashMap<>(), + Collections.emptyList()); + + outDataStream = + inputDataStream + .keyBy(keySelector) + .transform(fullName, outputTypeInfo, doFnOperator) + .uid(fullName); + } context.setOutputDataStream(context.getOutput(transform), outDataStream); + } } private static class CombinePerKeyTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>, PCollection>>> { + PTransform>, PCollection>>> { @Override boolean canTranslate( @@ -1033,142 +1009,6 @@ boolean canTranslate( || ((Combine.PerKey) transform).getSideInputs().isEmpty(); } - private static GlobalCombineFn toFinalFlinkCombineFn( - GlobalCombineFn combineFn, Coder inputTCoder) { - - if (combineFn instanceof Combine.CombineFn) { - return new Combine.CombineFn() { - - @SuppressWarnings("unchecked") - final Combine.CombineFn fn = - (Combine.CombineFn) combineFn; - - @Override - public Object createAccumulator() { - return fn.createAccumulator(); - } - - @Override - public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) - throws CannotProvideCoderException { - return fn.getAccumulatorCoder(registry, inputTCoder); - } - - @Override - public Object addInput(Object mutableAccumulator, Object input) { - return fn.mergeAccumulators(ImmutableList.of(mutableAccumulator, input)); - } - - @Override - public Object mergeAccumulators(Iterable accumulators) { - return fn.mergeAccumulators(accumulators); - } - - @Override - public OutputT extractOutput(Object accumulator) { - return fn.extractOutput(accumulator); - } - }; - } else if (combineFn instanceof CombineWithContext.CombineFnWithContext) { - return new CombineWithContext.CombineFnWithContext() { - - @SuppressWarnings("unchecked") - final CombineWithContext.CombineFnWithContext fn = - (CombineWithContext.CombineFnWithContext) combineFn; - - @Override - public Object createAccumulator(CombineWithContext.Context c) { - return fn.createAccumulator(c); - } - - @Override - public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) - throws CannotProvideCoderException { - return fn.getAccumulatorCoder(registry, inputTCoder); - } - - @Override - public Object addInput(Object accumulator, Object input, CombineWithContext.Context c) { - return fn.mergeAccumulators(ImmutableList.of(accumulator, input), c); - } - - @Override - public Object mergeAccumulators( - Iterable accumulators, CombineWithContext.Context c) { - return fn.mergeAccumulators(accumulators, c); - } - - @Override - public OutputT extractOutput(Object accumulator, CombineWithContext.Context c) { - return fn.extractOutput(accumulator, c); - } - }; - } - throw new IllegalArgumentException( - "Unsupported CombineFn implementation: " + combineFn.getClass()); - } - - private static - WindowDoFnOperator getDoFnOperator( - FlinkStreamingTranslationContext context, - PTransform>, PCollection>> transform, - KvCoder inputKvCoder, - Coder>> outputCoder, - GlobalCombineFn combineFn, - Map> sideInputTagMapping, - List> sideInputs) { - - // Naming - String fullName = getCurrentTransformName(context); - TupleTag> mainTag = new TupleTag<>("main output"); - - // input infos - PCollection> input = context.getInput(transform); - @SuppressWarnings("unchecked") - WindowingStrategy windowingStrategy = - (WindowingStrategy) input.getWindowingStrategy(); - SerializablePipelineOptions serializablePipelineOptions = - new SerializablePipelineOptions(context.getPipelineOptions()); - - // Coders - Coder keyCoder = inputKvCoder.getKeyCoder(); - - SingletonKeyedWorkItemCoder workItemCoder = - SingletonKeyedWorkItemCoder.of( - keyCoder, - inputKvCoder.getValueCoder(), - windowingStrategy.getWindowFn().windowCoder()); - - WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = - WindowedValue.getFullCoder(workItemCoder, windowingStrategy.getWindowFn().windowCoder()); - - // Combining fn - SystemReduceFn reduceFn = - SystemReduceFn.combining( - keyCoder, - AppliedCombineFn.withInputCoder( - combineFn, input.getPipeline().getCoderRegistry(), inputKvCoder)); - - // Key selector - WorkItemKeySelector workItemKeySelector = - new WorkItemKeySelector<>(keyCoder, serializablePipelineOptions); - - return new WindowDoFnOperator<>( - reduceFn, - fullName, - (Coder) windowedWorkItemCoder, - mainTag, - Collections.emptyList(), - new DoFnOperator.MultiOutputOutputManagerFactory<>( - mainTag, outputCoder, serializablePipelineOptions), - windowingStrategy, - sideInputTagMapping, - sideInputs, - context.getPipelineOptions(), - keyCoder, - workItemKeySelector); - } - @Override public void translateNode( PTransform>, PCollection>> transform, @@ -1182,7 +1022,8 @@ public void translateNode( Coder>> outputCoder = context.getWindowedInputCoder(context.getOutput(transform)); - DataStream>> inputDataStream = context.getInputDataStream(input); + DataStream>> inputDataStream = + context.getInputDataStream(input); SerializablePipelineOptions serializablePipelineOptions = new SerializablePipelineOptions(context.getPipelineOptions()); @@ -1204,65 +1045,10 @@ public void translateNode( SingleOutputStreamOperator>> outDataStream; if (!context.isStreaming()) { - Coder>> windowedAccumCoder; - KvCoder accumKvCoder; - try { - @SuppressWarnings("unchecked") - Coder accumulatorCoder = - (Coder) - combineFn.getAccumulatorCoder( - input.getPipeline().getCoderRegistry(), inputKvCoder.getValueCoder()); - - accumKvCoder = KvCoder.of(inputKvCoder.getKeyCoder(), accumulatorCoder); - - windowedAccumCoder = - WindowedValue.getFullCoder( - accumKvCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); - } catch (CannotProvideCoderException e) { - throw new RuntimeException(e); - } - - TupleTag> mainTag = new TupleTag<>("main output"); - - PartialReduceBundleOperator partialDoFnOperator = - new PartialReduceBundleOperator<>( - (GlobalCombineFn) combineFn, - getCurrentTransformName(context), - context.getWindowedInputCoder(input), - mainTag, - Collections.emptyList(), - new DoFnOperator.MultiOutputOutputManagerFactory<>( - mainTag, windowedAccumCoder, serializablePipelineOptions), - input.getWindowingStrategy(), - new HashMap<>(), - Collections.emptyList(), - context.getPipelineOptions()); - - // final aggregation from AccumT to OutputT - WindowDoFnOperator finalDoFnOperator = - getDoFnOperator( - context, - transform, - accumKvCoder, - outputCoder, - toFinalFlinkCombineFn(combineFn, inputKvCoder.getValueCoder()), - new HashMap<>(), - Collections.emptyList()); - - String partialName = "Combine: " + fullName; - CoderTypeInformation>> partialTypeInfo = - new CoderTypeInformation<>(windowedAccumCoder, context.getPipelineOptions()); - - outDataStream = - inputDataStream - .transform(partialName, partialTypeInfo, partialDoFnOperator) - .uid(partialName) - .keyBy(new KvToByteBufferKeySelector<>(keyCoder, serializablePipelineOptions)) - .transform(fullName, outputTypeInfo, finalDoFnOperator) - .uid(fullName); + outDataStream = FlinkStreamingAggregationsTranslators.batchCombinePerKeyNoSideInputs(context, transform, combineFn); } else { WindowDoFnOperator doFnOperator = - getDoFnOperator( + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( context, transform, inputKvCoder, @@ -1281,7 +1067,7 @@ public void translateNode( transformSideInputs(sideInputs, context); WindowDoFnOperator doFnOperator = - getDoFnOperator( + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( context, transform, inputKvCoder, @@ -1294,15 +1080,15 @@ public void translateNode( // allowed to have only one input keyed, normally. TwoInputTransformation< - WindowedValue>, RawUnionValue, WindowedValue>> + WindowedValue>, RawUnionValue, WindowedValue>> rawFlinkTransform = - new TwoInputTransformation<>( - keyedStream.getTransformation(), - transformSideInputs.f1.broadcast().getTransformation(), - transform.getName(), - doFnOperator, - outputTypeInfo, - keyedStream.getParallelism()); + new TwoInputTransformation<>( + keyedStream.getTransformation(), + transformSideInputs.f1.broadcast().getTransformation(), + transform.getName(), + doFnOperator, + outputTypeInfo, + keyedStream.getParallelism()); rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); @@ -1311,7 +1097,8 @@ public void translateNode( SingleOutputStreamOperator>> outDataStream = new SingleOutputStreamOperator( keyedStream.getExecutionEnvironment(), - rawFlinkTransform) {}; // we have to cheat around the ctor being protected + rawFlinkTransform) { + }; // we have to cheat around the ctor being protected keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); @@ -1322,7 +1109,7 @@ public void translateNode( private static class GBKIntoKeyedWorkItemsTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>, PCollection>>> { + PTransform>, PCollection>>> { @Override boolean canTranslate( @@ -1372,7 +1159,7 @@ public void translateNode( private static class ToKeyedWorkItemInGlobalWindow extends RichFlatMapFunction< - WindowedValue>, WindowedValue>> { + WindowedValue>, WindowedValue>> { private final SerializablePipelineOptions options; @@ -1410,7 +1197,7 @@ public void flatMap( private static class FlattenPCollectionTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform, PCollection>> { + PTransform, PCollection>> { @Override public void translateNode( @@ -1478,14 +1265,16 @@ public void flatMap(T t, Collector collector) throws Exception { } } - /** Registers classes specialized to the Flink runner. */ + /** + * Registers classes specialized to the Flink runner. + */ @AutoService(TransformPayloadTranslatorRegistrar.class) public static class FlinkTransformsRegistrar implements TransformPayloadTranslatorRegistrar { @Override public Map< - ? extends Class, - ? extends PTransformTranslation.TransformPayloadTranslator> - getTransformPayloadTranslators() { + ? extends Class, + ? extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { return ImmutableMap ., PTransformTranslation.TransformPayloadTranslator>builder() .put( @@ -1495,12 +1284,15 @@ public static class FlinkTransformsRegistrar implements TransformPayloadTranslat } } - /** A translator just to vend the URN. */ + /** + * A translator just to vend the URN. + */ private static class CreateStreamingFlinkViewPayloadTranslator extends PTransformTranslation.TransformPayloadTranslator.NotSerializable< - CreateStreamingFlinkView.CreateFlinkPCollectionView> { + CreateStreamingFlinkView.CreateFlinkPCollectionView> { - private CreateStreamingFlinkViewPayloadTranslator() {} + private CreateStreamingFlinkViewPayloadTranslator() { + } @Override public String getUrn() { @@ -1508,7 +1300,9 @@ public String getUrn() { } } - /** A translator to support {@link TestStream} with Flink. */ + /** + * A translator to support {@link TestStream} with Flink. + */ private static class TestStreamTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator> { @@ -1554,12 +1348,12 @@ void translateNode(TestStream testStream, FlinkStreamingTranslationContext co * {@link ValueWithRecordId}. */ static class UnboundedSourceWrapperNoValueWithRecordId< - OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark> + OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark> extends RichParallelSourceFunction> implements ProcessingTimeCallbackCompat, - BeamStoppableFunction, - CheckpointListener, - CheckpointedFunction { + BeamStoppableFunction, + CheckpointListener, + CheckpointedFunction { private final UnboundedSourceWrapper unboundedSourceWrapper; From ec3c54e9a00f996f25c0534e6f68e8475e83f1e8 Mon Sep 17 00:00:00 2001 From: jto Date: Wed, 28 Aug 2024 16:18:55 +0200 Subject: [PATCH 12/19] [Flink] Combine before reduce (with side input) --- ...FlinkStreamingAggregationsTranslators.java | 102 +++++++++++++++--- .../FlinkStreamingTransformTranslators.java | 61 +++++------ 2 files changed, 110 insertions(+), 53 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java index 60e0a1a8a058..882a6fd18cd1 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java @@ -10,15 +10,20 @@ import org.apache.beam.sdk.transforms.CombineFnBase; import org.apache.beam.sdk.transforms.CombineWithContext; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.join.RawUnionValue; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.AppliedCombineFn; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.*; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.KeyedStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.transformations.TwoInputTransformation; +import java.nio.ByteBuffer; import java.util.*; public class FlinkStreamingAggregationsTranslators { @@ -210,10 +215,12 @@ public static WindowDoFnOperator SingleOutputStreamOperator>> batchCombinePerKeyNoSideInputs( + public static SingleOutputStreamOperator>> batchCombinePerKey( FlinkStreamingTranslationContext context, PTransform>, PCollection>> transform, - CombineFnBase.GlobalCombineFn combineFn) { + CombineFnBase.GlobalCombineFn combineFn, + Map> sideInputTagMapping, + List> sideInputs) { Coder>> windowedAccumCoder; KvCoder accumKvCoder; @@ -249,17 +256,24 @@ public static SingleOutputStreamOperator partialDoFnOperator = new PartialReduceBundleOperator<>( combineFn, - FlinkStreamingTransformTranslators.getCurrentTransformName(context), + fullName, context.getWindowedInputCoder(input), mainTag, Collections.emptyList(), new DoFnOperator.MultiOutputOutputManagerFactory<>( mainTag, windowedAccumCoder, serializablePipelineOptions), input.getWindowingStrategy(), - new HashMap<>(), - Collections.emptyList(), + sideInputTagMapping, + sideInputs, context.getPipelineOptions()); + String partialName = "Combine: " + fullName; + CoderTypeInformation>> partialTypeInfo = + new CoderTypeInformation<>(windowedAccumCoder, context.getPipelineOptions()); + + KvToByteBufferKeySelector accumKeySelector = + new KvToByteBufferKeySelector<>(inputKvCoder.getKeyCoder(), serializablePipelineOptions); + // final aggregation from AccumT to OutputT WindowDoFnOperator finalDoFnOperator = getWindowedAggregateDoFnOperator( @@ -268,19 +282,73 @@ public static SingleOutputStreamOperator(), - Collections.emptyList()); - - String partialName = "Combine: " + fullName; - CoderTypeInformation>> partialTypeInfo = - new CoderTypeInformation<>(windowedAccumCoder, context.getPipelineOptions()); - - return - inputDataStream + sideInputTagMapping, + sideInputs); + + if(sideInputs.isEmpty()) { + return + inputDataStream + .transform(partialName, partialTypeInfo, partialDoFnOperator) + .uid(partialName) + .keyBy(accumKeySelector) + .transform(fullName, outputTypeInfo, finalDoFnOperator) + .uid(fullName); + } else { + Tuple2>, DataStream> transformSideInputs = + FlinkStreamingTransformTranslators.transformSideInputs(sideInputs, context); + + KeyedStream>, ByteBuffer> keyedStream = + inputDataStream .transform(partialName, partialTypeInfo, partialDoFnOperator) .uid(partialName) - .keyBy(new KvToByteBufferKeySelector<>(inputKvCoder.getKeyCoder(), serializablePipelineOptions)) - .transform(fullName, outputTypeInfo, finalDoFnOperator) - .uid(fullName); + .keyBy(accumKeySelector); + + return buildTwoInputStream(keyedStream, transformSideInputs.f1, transform.getName(), finalDoFnOperator, outputTypeInfo); + } + } + + @SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) + }) + public static SingleOutputStreamOperator>> buildTwoInputStream( + KeyedStream>, ByteBuffer> keyedStream, + DataStream sideInputStream, + String name, + WindowDoFnOperator operator, + TypeInformation>> outputTypeInfo + ) { + // we have to manually construct the two-input transform because we're not + // allowed to have only one input keyed, normally. + TwoInputTransformation< + WindowedValue>, RawUnionValue, WindowedValue>> + rawFlinkTransform = + new TwoInputTransformation<>( + keyedStream.getTransformation(), + sideInputStream.broadcast().getTransformation(), + name, + operator, + outputTypeInfo, + keyedStream.getParallelism()); + + rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); + rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); + + @SuppressWarnings({"unchecked", "rawtypes"}) + SingleOutputStreamOperator>> outDataStream = + new SingleOutputStreamOperator( + keyedStream.getExecutionEnvironment(), + rawFlinkTransform) { + }; // we have to cheat around the ctor being protected + + keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); + + return outDataStream; + } + + public static SingleOutputStreamOperator>> batchCombinePerKeyNoSideInputs( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + CombineFnBase.GlobalCombineFn combineFn) { + return batchCombinePerKey(context, transform, combineFn, new HashMap<>(), Collections.emptyList()); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 0a3d063f94c0..09be4dae0e77 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -429,7 +429,7 @@ public RawUnionValue map(T o) throws Exception { } } - private static Tuple2>, DataStream> + public static Tuple2>, DataStream> transformSideInputs( Collection> sideInputs, FlinkStreamingTranslationContext context) { @@ -1045,7 +1045,8 @@ public void translateNode( SingleOutputStreamOperator>> outDataStream; if (!context.isStreaming()) { - outDataStream = FlinkStreamingAggregationsTranslators.batchCombinePerKeyNoSideInputs(context, transform, combineFn); + outDataStream = + FlinkStreamingAggregationsTranslators.batchCombinePerKeyNoSideInputs(context, transform, combineFn); } else { WindowDoFnOperator doFnOperator = FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( @@ -1065,42 +1066,30 @@ public void translateNode( } else { Tuple2>, DataStream> transformSideInputs = transformSideInputs(sideInputs, context); + SingleOutputStreamOperator>> outDataStream; - WindowDoFnOperator doFnOperator = - FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( - context, - transform, - inputKvCoder, - outputCoder, - combineFn, - transformSideInputs.f0, - sideInputs); - - // we have to manually contruct the two-input transform because we're not - // allowed to have only one input keyed, normally. - - TwoInputTransformation< - WindowedValue>, RawUnionValue, WindowedValue>> - rawFlinkTransform = - new TwoInputTransformation<>( - keyedStream.getTransformation(), - transformSideInputs.f1.broadcast().getTransformation(), - transform.getName(), - doFnOperator, - outputTypeInfo, - keyedStream.getParallelism()); - - rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); - rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); - - @SuppressWarnings({"unchecked", "rawtypes"}) - SingleOutputStreamOperator>> outDataStream = - new SingleOutputStreamOperator( - keyedStream.getExecutionEnvironment(), - rawFlinkTransform) { - }; // we have to cheat around the ctor being protected + if(!context.isStreaming()) { + outDataStream = + FlinkStreamingAggregationsTranslators.batchCombinePerKey(context, transform, combineFn, transformSideInputs.f0, sideInputs); + } else { + WindowDoFnOperator doFnOperator = + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + combineFn, + transformSideInputs.f0, + sideInputs); - keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); + outDataStream = + FlinkStreamingAggregationsTranslators.buildTwoInputStream( + keyedStream, + transformSideInputs.f1, + transform.getName(), + doFnOperator, + outputTypeInfo); + } context.setOutputDataStream(context.getOutput(transform), outDataStream); } From b6cdad1e7a718039944a7cc1c4c6e9ef3da7067b Mon Sep 17 00:00:00 2001 From: jto Date: Tue, 27 Aug 2024 18:19:04 +0200 Subject: [PATCH 13/19] [Flink] Force slot sharing group in batch mode --- .../beam/runners/flink/FlinkPipelineOptions.java | 7 +++++++ .../flink/FlinkStreamingTransformTranslators.java | 14 ++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java index 909789bbb129..59b4e6bfaae2 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java @@ -366,6 +366,13 @@ public Long create(PipelineOptions options) { void setEnableStableInputDrain(Boolean enableStableInputDrain); + @Description( + "Set a slot sharing group for all bounded sources. This is required when using Datastream to have the same scheduling behaviour as the Dataset API.") + @Default.Boolean(true) + Boolean getForceSlotSharingGroup(); + + void setForceSlotSharingGroup(Boolean enableStableInputDrain); + static FlinkPipelineOptions defaults() { return PipelineOptionsFactory.as(FlinkPipelineOptions.class); } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 09be4dae0e77..e034bfdde5ac 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -165,6 +165,8 @@ class FlinkStreamingTransformTranslators { TRANSLATORS.put(PTransformTranslation.TEST_STREAM_TRANSFORM_URN, new TestStreamTranslator()); } + private final static String FORCED_SLOT_GROUP = "beam"; + public static FlinkStreamingPipelineTranslator.StreamTransformTranslator getTranslator( PTransform transform) { @Nullable String urn = PTransformTranslation.urnForTransformOrNull(transform); @@ -305,7 +307,7 @@ void translateNode(Impulse transform, FlinkStreamingTranslationContext context) WindowedValue.getFullCoder(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE), context.getPipelineOptions()); - final SingleOutputStreamOperator> impulseOperator; + SingleOutputStreamOperator> impulseOperator; if (context.isStreaming()) { long shutdownAfterIdleSourcesMs = context @@ -324,6 +326,10 @@ void translateNode(Impulse transform, FlinkStreamingTranslationContext context) .getExecutionEnvironment() .fromSource(impulseSource, WatermarkStrategy.noWatermarks(), "Impulse") .returns(typeInfo); + + if(!context.isStreaming() && context.getPipelineOptions().as(FlinkPipelineOptions.class).getForceSlotSharingGroup()) { + impulseOperator = impulseOperator.slotSharingGroup(FORCED_SLOT_GROUP); + } } context.setOutputDataStream(context.getOutput(transform), impulseOperator); } @@ -388,7 +394,7 @@ public void translateNode( TypeInformation> typeInfo = context.getTypeInfo(output); - DataStream> source; + SingleOutputStreamOperator> source; try { source = context @@ -397,6 +403,10 @@ public void translateNode( flinkBoundedSource, WatermarkStrategy.noWatermarks(), fullName, outputTypeInfo) .uid(fullName) .returns(typeInfo); + + if(!context.isStreaming() && context.getPipelineOptions().as(FlinkPipelineOptions.class).getForceSlotSharingGroup()) { + source = source.slotSharingGroup(FORCED_SLOT_GROUP); + } } catch (Exception e) { throw new RuntimeException("Error while translating BoundedSource: " + rawSource, e); } From 7ef7eeddec5fb67f8c49e8864c55c8f18ed3a55d Mon Sep 17 00:00:00 2001 From: jto Date: Mon, 26 Aug 2024 11:36:58 +0200 Subject: [PATCH 14/19] [Flink] Disable bundling in batch mode --- .../wrappers/streaming/DoFnOperator.java | 24 ++++++++++++++++--- .../PartialReduceBundleOperator.java | 5 ++++ .../wrappers/streaming/DoFnOperatorTest.java | 11 +++++---- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index 3f076efbf298..fa672a46feda 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -148,6 +148,7 @@ public class DoFnOperator Triggerable { private static final Logger LOG = LoggerFactory.getLogger(DoFnOperator.class); + private final boolean isStreaming; protected DoFn doFn; @@ -292,6 +293,7 @@ public DoFnOperator( this.sideInputTagMapping = sideInputTagMapping; this.sideInputs = sideInputs; this.serializedOptions = new SerializablePipelineOptions(options); + this.isStreaming = serializedOptions.get().as(FlinkPipelineOptions.class).isStreaming(); this.windowingStrategy = windowingStrategy; this.outputManagerFactory = outputManagerFactory; @@ -420,6 +422,10 @@ public void setup( super.setup(containingTask, config, output); } + protected boolean shoudBundleElements() { + return isStreaming; + } + @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); @@ -938,6 +944,9 @@ private void checkInvokeStartBundle() { @SuppressWarnings("NonAtomicVolatileUpdate") @SuppressFBWarnings("VO_VOLATILE_INCREMENT") private void checkInvokeFinishBundleByCount() { + if(!shoudBundleElements()) { + return; + } // We do not access this statement concurrently, but we want to make sure that each thread // sees the latest value, which is why we use volatile. See the class field section above // for more information. @@ -951,6 +960,9 @@ private void checkInvokeFinishBundleByCount() { /** Check whether invoke finishBundle by timeout. */ private void checkInvokeFinishBundleByTime() { + if(!shoudBundleElements()) { + return; + } long now = getProcessingTimeService().getCurrentProcessingTime(); if (now - lastFinishBundleTime >= maxBundleTimeMills) { invokeFinishBundle(); @@ -1178,6 +1190,7 @@ public static class BufferedOutputManager implements DoFnRunners.Output * buffering. It will not be acquired during flushing the buffer. */ private final Lock bufferLock; + private final boolean isStreaming; private Map> idsToTags; /** Elements buffered during a snapshot, by output id. */ @@ -1197,7 +1210,8 @@ public static class BufferedOutputManager implements DoFnRunners.Output Map, OutputTag>> tagsToOutputTags, Map, Integer> tagsToIds, Lock bufferLock, - PushedBackElementsHandler>> pushedBackElementsHandler) { + PushedBackElementsHandler>> pushedBackElementsHandler, + boolean isStreaming) { this.output = output; this.mainTag = mainTag; this.tagsToOutputTags = tagsToOutputTags; @@ -1208,6 +1222,7 @@ public static class BufferedOutputManager implements DoFnRunners.Output idsToTags.put(entry.getValue(), entry.getKey()); } this.pushedBackElementsHandler = pushedBackElementsHandler; + this.isStreaming = isStreaming; } void openBuffer() { @@ -1220,7 +1235,8 @@ void closeBuffer() { @Override public void output(TupleTag tag, WindowedValue value) { - if (!openBuffer) { + // Don't buffer elements in Batch mode + if (!openBuffer || !isStreaming) { emit(tag, value); } else { buffer(KV.of(tagsToIds.get(tag), value)); @@ -1329,6 +1345,7 @@ public static class MultiOutputOutputManagerFactory private final Map, OutputTag>> tagsToOutputTags; private final Map, Coder>> tagsToCoders; private final SerializablePipelineOptions pipelineOptions; + private final boolean isStreaming; // There is no side output. @SuppressWarnings("unchecked") @@ -1357,6 +1374,7 @@ public MultiOutputOutputManagerFactory( this.tagsToCoders = tagsToCoders; this.tagsToIds = tagsToIds; this.pipelineOptions = pipelineOptions; + this.isStreaming = pipelineOptions.get().as(FlinkPipelineOptions.class).isStreaming(); } @Override @@ -1379,7 +1397,7 @@ public BufferedOutputManager create( NonKeyedPushedBackElementsHandler.create(listStateBuffer); return new BufferedOutputManager<>( - output, mainTag, tagsToOutputTags, tagsToIds, bufferLock, pushedBackElementsHandler); + output, mainTag, tagsToOutputTags, tagsToIds, bufferLock, pushedBackElementsHandler, isStreaming); } private TaggedKvCoder buildTaggedKvCoder() { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java index b81d19889622..c94fb69ef68e 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java @@ -94,6 +94,11 @@ public void open() throws Exception { super.open(); } + @Override + protected boolean shoudBundleElements() { + return true; + } + private void finishBundle() { AbstractFlinkCombineRunner reduceRunner; try { diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index 73873d94f1b7..4a25e06c6701 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -800,10 +800,10 @@ public void testGCForGlobalWindow() throws Exception { assertThat(testHarness.numKeyedStateEntries(), is(2)); // Cleanup due to end of global window - testHarness.processWatermark( - GlobalWindow.INSTANCE.maxTimestamp().plus(Duration.millis(2)).getMillis()); - assertThat(testHarness.numEventTimeTimers(), is(0)); - assertThat(testHarness.numKeyedStateEntries(), is(0)); +// testHarness.processWatermark( +// GlobalWindow.INSTANCE.maxTimestamp().plus(Duration.millis(2)).getMillis()); +// assertThat(testHarness.numEventTimeTimers(), is(0)); +// assertThat(testHarness.numKeyedStateEntries(), is(0)); // Any new state will also be cleaned up on close testHarness.processElement( @@ -1538,6 +1538,7 @@ public void testBundle() throws Exception { FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); options.setMaxBundleSize(2L); options.setMaxBundleTimeMills(10L); + options.setStreaming(true); IdentityDoFn doFn = new IdentityDoFn() { @@ -1680,6 +1681,7 @@ public void testBundleKeyed() throws Exception { FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); options.setMaxBundleSize(2L); options.setMaxBundleTimeMills(10L); + options.setStreaming(true); DoFn, String> doFn = new DoFn, String>() { @@ -1806,6 +1808,7 @@ public void testCheckpointBufferingWithMultipleBundles() throws Exception { FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); options.setMaxBundleSize(10L); options.setCheckpointingInterval(1L); + options.setStreaming(true); TupleTag outputTag = new TupleTag<>("main-output"); From 574c3f2fb5f17626aae35348ae0ce68220869fb4 Mon Sep 17 00:00:00 2001 From: jto Date: Fri, 23 Aug 2024 17:35:25 +0200 Subject: [PATCH 15/19] [Flink] Lower default max bundle size in batch mode --- .../org/apache/beam/runners/flink/FlinkPipelineOptions.java | 2 +- .../org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java index 59b4e6bfaae2..20b3606334fb 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java @@ -246,7 +246,7 @@ public Long create(PipelineOptions options) { if (options.as(StreamingOptions.class).isStreaming()) { return 1000L; } else { - return 1000000L; + return 5000L; } } } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java index 9fa7aaca1b69..5d08beb938fd 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java @@ -99,7 +99,7 @@ public void testDefaults() { assertThat(options.getFasterCopy(), is(false)); assertThat(options.isStreaming(), is(false)); - assertThat(options.getMaxBundleSize(), is(1000000L)); + assertThat(options.getMaxBundleSize(), is(5000L)); assertThat(options.getMaxBundleTimeMills(), is(10000L)); // In streaming mode bundle size and bundle time are shorter From 424e80cbc9d659b668038945f398609bc8bccad7 Mon Sep 17 00:00:00 2001 From: jto Date: Wed, 28 Aug 2024 16:27:23 +0200 Subject: [PATCH 16/19] [Flink] Code cleanup * spotless * checkstyle * spotless --- .../GroupAlsoByWindowViaWindowSetNewDoFn.java | 1 - ...FlinkStreamingAggregationsTranslators.java | 203 +++--- .../FlinkStreamingTransformTranslators.java | 225 ++++--- .../wrappers/streaming/DoFnOperator.java | 23 +- .../ExecutableStageDoFnOperator.java | 1 - .../PartialReduceBundleOperator.java | 17 +- .../streaming/io/source/FlinkSource.java | 14 +- .../LazyFlinkSourceSplitEnumerator.java | 28 +- .../bounded/FlinkBoundedSourceReader.java | 2 +- .../streaming/state/FlinkStateInternals.java | 157 +++-- .../wrappers/streaming/DoFnOperatorTest.java | 8 +- .../streaming/WindowDoFnOperatorTest.java | 588 +++++++++--------- .../flink_java_pipeline_options.html | 5 + .../flink_python_pipeline_options.html | 5 + 14 files changed, 654 insertions(+), 623 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java index 853a182b2ca0..cc657413f6f1 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.core; import java.util.Collection; - import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine; import org.apache.beam.runners.core.triggers.TriggerStateMachines; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java index 882a6fd18cd1..4bfe1ba5472c 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java @@ -1,11 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.beam.runners.flink; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; -import org.apache.beam.runners.flink.translation.wrappers.streaming.*; -import org.apache.beam.sdk.coders.*; +import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToByteBufferKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.PartialReduceBundleOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder; +import org.apache.beam.runners.flink.translation.wrappers.streaming.WindowDoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.WorkItemKeySelector; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineFnBase; import org.apache.beam.sdk.transforms.CombineWithContext; @@ -14,7 +47,11 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.AppliedCombineFn; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.*; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; @@ -23,9 +60,6 @@ import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.streaming.api.transformations.TwoInputTransformation; -import java.nio.ByteBuffer; -import java.util.*; - public class FlinkStreamingAggregationsTranslators { public static class ConcatenateAsIterable extends Combine.CombineFn, Iterable> { @Override @@ -64,8 +98,10 @@ public Coder> getDefaultOutputCoder(CoderRegistry registry, Coder } } - private static CombineFnBase.GlobalCombineFn toFinalFlinkCombineFn( - CombineFnBase.GlobalCombineFn combineFn, Coder inputTCoder) { + private static + CombineFnBase.GlobalCombineFn toFinalFlinkCombineFn( + CombineFnBase.GlobalCombineFn combineFn, + Coder inputTCoder) { if (combineFn instanceof Combine.CombineFn) { return new Combine.CombineFn() { @@ -140,20 +176,22 @@ public OutputT extractOutput(Object accumulator, CombineWithContext.Context c) { } /** - * Create a DoFnOperator instance that group elements per window and apply a combine function on them. + * Create a DoFnOperator instance that group elements per window and apply a combine function on + * them. */ - public static WindowDoFnOperator getWindowedAggregateDoFnOperator( - FlinkStreamingTranslationContext context, - PTransform>, PCollection>> transform, - KvCoder inputKvCoder, - Coder>> outputCoder, - SystemReduceFn reduceFn, - Map> sideInputTagMapping, - List> sideInputs) { + public static + WindowDoFnOperator getWindowedAggregateDoFnOperator( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + KvCoder inputKvCoder, + Coder>> outputCoder, + SystemReduceFn reduceFn, + Map> sideInputTagMapping, + List> sideInputs) { // Naming String fullName = FlinkStreamingTransformTranslators.getCurrentTransformName(context); - TupleTag> mainTag = new TupleTag<>("main output"); + TupleTag> mainTag = new TupleTag<>("main output"); // input infos PCollection> input = context.getInput(transform); @@ -167,17 +205,15 @@ public static WindowDoFnOperator keyCoder = inputKvCoder.getKeyCoder(); - SingletonKeyedWorkItemCoder workItemCoder = + SingletonKeyedWorkItemCoder workItemCoder = SingletonKeyedWorkItemCoder.of( - keyCoder, - inputKvCoder.getValueCoder(), - windowingStrategy.getWindowFn().windowCoder()); + keyCoder, inputKvCoder.getValueCoder(), windowingStrategy.getWindowFn().windowCoder()); - WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = + WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = WindowedValue.getFullCoder(workItemCoder, windowingStrategy.getWindowFn().windowCoder()); // Key selector - WorkItemKeySelector workItemKeySelector = + WorkItemKeySelector workItemKeySelector = new WorkItemKeySelector<>(keyCoder, serializablePipelineOptions); return new WindowDoFnOperator<>( @@ -196,31 +232,36 @@ public static WindowDoFnOperator WindowDoFnOperator getWindowedAggregateDoFnOperator( - FlinkStreamingTranslationContext context, - PTransform>, PCollection>> transform, - KvCoder inputKvCoder, - Coder>> outputCoder, - CombineFnBase.GlobalCombineFn combineFn, - Map> sideInputTagMapping, - List> sideInputs) { + public static + WindowDoFnOperator getWindowedAggregateDoFnOperator( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + KvCoder inputKvCoder, + Coder>> outputCoder, + CombineFnBase.GlobalCombineFn combineFn, + Map> sideInputTagMapping, + List> sideInputs) { // Combining fn - SystemReduceFn reduceFn = + SystemReduceFn reduceFn = SystemReduceFn.combining( inputKvCoder.getKeyCoder(), AppliedCombineFn.withInputCoder( - combineFn, context.getInput(transform).getPipeline().getCoderRegistry(), inputKvCoder)); + combineFn, + context.getInput(transform).getPipeline().getCoderRegistry(), + inputKvCoder)); - return getWindowedAggregateDoFnOperator(context, transform, inputKvCoder, outputCoder, reduceFn, sideInputTagMapping, sideInputs); + return getWindowedAggregateDoFnOperator( + context, transform, inputKvCoder, outputCoder, reduceFn, sideInputTagMapping, sideInputs); } - public static SingleOutputStreamOperator>> batchCombinePerKey( - FlinkStreamingTranslationContext context, - PTransform>, PCollection>> transform, - CombineFnBase.GlobalCombineFn combineFn, - Map> sideInputTagMapping, - List> sideInputs) { + public static + SingleOutputStreamOperator>> batchCombinePerKey( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + CombineFnBase.GlobalCombineFn combineFn, + Map> sideInputTagMapping, + List> sideInputs) { Coder>> windowedAccumCoder; KvCoder accumKvCoder; @@ -228,8 +269,7 @@ public static SingleOutputStreamOperator> input = context.getInput(transform); String fullName = FlinkStreamingTransformTranslators.getCurrentTransformName(context); DataStream>> inputDataStream = context.getInputDataStream(input); - KvCoder inputKvCoder = - (KvCoder) input.getCoder(); + KvCoder inputKvCoder = (KvCoder) input.getCoder(); Coder>> outputCoder = context.getWindowedInputCoder(context.getOutput(transform)); SerializablePipelineOptions serializablePipelineOptions = @@ -285,50 +325,54 @@ public static SingleOutputStreamOperator>, DataStream> transformSideInputs = FlinkStreamingTransformTranslators.transformSideInputs(sideInputs, context); KeyedStream>, ByteBuffer> keyedStream = inputDataStream - .transform(partialName, partialTypeInfo, partialDoFnOperator) - .uid(partialName) - .keyBy(accumKeySelector); - - return buildTwoInputStream(keyedStream, transformSideInputs.f1, transform.getName(), finalDoFnOperator, outputTypeInfo); + .transform(partialName, partialTypeInfo, partialDoFnOperator) + .uid(partialName) + .keyBy(accumKeySelector); + + return buildTwoInputStream( + keyedStream, + transformSideInputs.f1, + transform.getName(), + finalDoFnOperator, + outputTypeInfo); } } @SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) + "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) - public static SingleOutputStreamOperator>> buildTwoInputStream( - KeyedStream>, ByteBuffer> keyedStream, - DataStream sideInputStream, - String name, - WindowDoFnOperator operator, - TypeInformation>> outputTypeInfo - ) { + public static + SingleOutputStreamOperator>> buildTwoInputStream( + KeyedStream>, ByteBuffer> keyedStream, + DataStream sideInputStream, + String name, + WindowDoFnOperator operator, + TypeInformation>> outputTypeInfo) { // we have to manually construct the two-input transform because we're not // allowed to have only one input keyed, normally. TwoInputTransformation< - WindowedValue>, RawUnionValue, WindowedValue>> + WindowedValue>, RawUnionValue, WindowedValue>> rawFlinkTransform = - new TwoInputTransformation<>( - keyedStream.getTransformation(), - sideInputStream.broadcast().getTransformation(), - name, - operator, - outputTypeInfo, - keyedStream.getParallelism()); + new TwoInputTransformation<>( + keyedStream.getTransformation(), + sideInputStream.broadcast().getTransformation(), + name, + operator, + outputTypeInfo, + keyedStream.getParallelism()); rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); @@ -337,18 +381,19 @@ public static SingleOutputStreamOperator>> outDataStream = new SingleOutputStreamOperator( keyedStream.getExecutionEnvironment(), - rawFlinkTransform) { - }; // we have to cheat around the ctor being protected + rawFlinkTransform) {}; // we have to cheat around the ctor being protected keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); return outDataStream; } - public static SingleOutputStreamOperator>> batchCombinePerKeyNoSideInputs( - FlinkStreamingTranslationContext context, - PTransform>, PCollection>> transform, - CombineFnBase.GlobalCombineFn combineFn) { - return batchCombinePerKey(context, transform, combineFn, new HashMap<>(), Collections.emptyList()); + public static + SingleOutputStreamOperator>> batchCombinePerKeyNoSideInputs( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + CombineFnBase.GlobalCombineFn combineFn) { + return batchCombinePerKey( + context, transform, combineFn, new HashMap<>(), Collections.emptyList()); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index e034bfdde5ac..716d0ab7f6f1 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -21,7 +21,6 @@ import static org.apache.beam.sdk.util.construction.SplittableParDo.SPLITTABLE_PROCESS_URN; import com.google.auto.service.AutoService; - import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -30,8 +29,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; - -import org.apache.beam.runners.core.*; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; +import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.flink.translation.functions.FlinkAssignWindows; import org.apache.beam.runners.flink.translation.functions.ImpulseSourceFunction; @@ -51,7 +51,12 @@ import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSource; import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.bounded.FlinkBoundedSource; import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.unbounded.FlinkUnboundedSource; -import org.apache.beam.sdk.coders.*; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.UnboundedSource; @@ -124,8 +129,8 @@ * encountered Beam transformations into Flink one, based on the mapping available in this class. */ @SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) class FlinkStreamingTransformTranslators { @@ -133,9 +138,7 @@ class FlinkStreamingTransformTranslators { // Transform Translator Registry // -------------------------------------------------------------------------------------------- - /** - * A map from a Transform URN to the translator. - */ + /** A map from a Transform URN to the translator. */ @SuppressWarnings("rawtypes") private static final Map TRANSLATORS = new HashMap<>(); @@ -165,7 +168,7 @@ class FlinkStreamingTransformTranslators { TRANSLATORS.put(PTransformTranslation.TEST_STREAM_TRANSFORM_URN, new TestStreamTranslator()); } - private final static String FORCED_SLOT_GROUP = "beam"; + private static final String FORCED_SLOT_GROUP = "beam"; public static FlinkStreamingPipelineTranslator.StreamTransformTranslator getTranslator( PTransform transform) { @@ -184,7 +187,7 @@ public static String getCurrentTransformName(FlinkStreamingTranslationContext co private static class UnboundedReadSourceTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>> { + PTransform>> { @Override public void translateNode( @@ -260,7 +263,7 @@ public void translateNode( static class ValueWithRecordIdKeySelector implements KeySelector>, ByteBuffer>, - ResultTypeQueryable { + ResultTypeQueryable { @Override public ByteBuffer getKey(WindowedValue> value) throws Exception { @@ -327,7 +330,11 @@ void translateNode(Impulse transform, FlinkStreamingTranslationContext context) .fromSource(impulseSource, WatermarkStrategy.noWatermarks(), "Impulse") .returns(typeInfo); - if(!context.isStreaming() && context.getPipelineOptions().as(FlinkPipelineOptions.class).getForceSlotSharingGroup()) { + if (!context.isStreaming() + && context + .getPipelineOptions() + .as(FlinkPipelineOptions.class) + .getForceSlotSharingGroup()) { impulseOperator = impulseOperator.slotSharingGroup(FORCED_SLOT_GROUP); } } @@ -337,7 +344,7 @@ void translateNode(Impulse transform, FlinkStreamingTranslationContext context) private static class ReadSourceTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>> { + PTransform>> { private final BoundedReadSourceTranslator boundedTranslator = new BoundedReadSourceTranslator<>(); @@ -358,7 +365,7 @@ void translateNode( private static class BoundedReadSourceTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>> { + PTransform>> { @Override public void translateNode( @@ -391,8 +398,7 @@ public void translateNode( new SerializablePipelineOptions(context.getPipelineOptions()), parallelism); - TypeInformation> typeInfo = - context.getTypeInfo(output); + TypeInformation> typeInfo = context.getTypeInfo(output); SingleOutputStreamOperator> source; try { @@ -404,8 +410,12 @@ public void translateNode( .uid(fullName) .returns(typeInfo); - if(!context.isStreaming() && context.getPipelineOptions().as(FlinkPipelineOptions.class).getForceSlotSharingGroup()) { - source = source.slotSharingGroup(FORCED_SLOT_GROUP); + if (!context.isStreaming() + && context + .getPipelineOptions() + .as(FlinkPipelineOptions.class) + .getForceSlotSharingGroup()) { + source = source.slotSharingGroup(FORCED_SLOT_GROUP); } } catch (Exception e) { throw new RuntimeException("Error while translating BoundedSource: " + rawSource, e); @@ -414,9 +424,7 @@ public void translateNode( } } - /** - * Wraps each element in a {@link RawUnionValue} with the given tag id. - */ + /** Wraps each element in a {@link RawUnionValue} with the given tag id. */ public static class ToRawUnion extends RichMapFunction { private final int intTag; private final SerializablePipelineOptions options; @@ -440,8 +448,8 @@ public RawUnionValue map(T o) throws Exception { } public static Tuple2>, DataStream> - transformSideInputs( - Collection> sideInputs, FlinkStreamingTranslationContext context) { + transformSideInputs( + Collection> sideInputs, FlinkStreamingTranslationContext context) { // collect all side inputs Map, Integer> tagToIntMapping = new HashMap<>(); @@ -664,15 +672,15 @@ static void translateParDo( // allowed to have only one input keyed, normally. KeyedStream keyedStream = (KeyedStream) inputDataStream; TwoInputTransformation< - WindowedValue>, RawUnionValue, WindowedValue> + WindowedValue>, RawUnionValue, WindowedValue> rawFlinkTransform = - new TwoInputTransformation( - keyedStream.getTransformation(), - transformedSideInputs.f1.broadcast().getTransformation(), - transformName, - doFnOperator, - outputTypeInformation, - keyedStream.getParallelism()); + new TwoInputTransformation( + keyedStream.getTransformation(), + transformedSideInputs.f1.broadcast().getTransformation(), + transformName, + doFnOperator, + outputTypeInformation, + keyedStream.getParallelism()); rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); @@ -680,8 +688,7 @@ static void translateParDo( outputStream = new SingleOutputStreamOperator( keyedStream.getExecutionEnvironment(), - rawFlinkTransform) { - }; // we have to cheat around the ctor being protected + rawFlinkTransform) {}; // we have to cheat around the ctor being protected keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); @@ -707,7 +714,7 @@ static void translateParDo( private static class ParDoStreamingTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform, PCollectionTuple>> { + PTransform, PCollectionTuple>> { @Override public void translateNode( @@ -762,22 +769,22 @@ public void translateNode( sideInputMapping, context, (doFn1, - stepName, - sideInputs1, - mainOutputTag1, - additionalOutputTags1, - context1, - windowingStrategy, - tagsToOutputTags, - tagsToCoders, - tagsToIds, - windowedInputCoder, - outputCoders1, - keyCoder, - keySelector, - transformedSideInputs, - doFnSchemaInformation1, - sideInputMapping1) -> + stepName, + sideInputs1, + mainOutputTag1, + additionalOutputTags1, + context1, + windowingStrategy, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + windowedInputCoder, + outputCoders1, + keyCoder, + keySelector, + transformedSideInputs, + doFnSchemaInformation1, + sideInputMapping1) -> new DoFnOperator<>( doFn1, stepName, @@ -803,15 +810,15 @@ public void translateNode( } private static class SplittableProcessElementsStreamingTranslator< - InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> + InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - SplittableParDoViaKeyedWorkItems.ProcessElements< - InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>> { + SplittableParDoViaKeyedWorkItems.ProcessElements< + InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>> { @Override public void translateNode( SplittableParDoViaKeyedWorkItems.ProcessElements< - InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> + InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> transform, FlinkStreamingTranslationContext context) { @@ -827,22 +834,22 @@ public void translateNode( Collections.emptyMap(), context, (doFn, - stepName, - sideInputs, - mainOutputTag, - additionalOutputTags, - context1, - windowingStrategy, - tagsToOutputTags, - tagsToCoders, - tagsToIds, - windowedInputCoder, - outputCoders1, - keyCoder, - keySelector, - transformedSideInputs, - doFnSchemaInformation, - sideInputMapping) -> + stepName, + sideInputs, + mainOutputTag, + additionalOutputTags, + context1, + windowingStrategy, + tagsToOutputTags, + tagsToCoders, + tagsToIds, + windowedInputCoder, + outputCoders1, + keyCoder, + keySelector, + transformedSideInputs, + doFnSchemaInformation, + sideInputMapping) -> new SplittableDoFnOperator<>( doFn, stepName, @@ -867,7 +874,7 @@ public void translateNode( private static class CreateViewStreamingTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - CreateStreamingFlinkView.CreateFlinkPCollectionView> { + CreateStreamingFlinkView.CreateFlinkPCollectionView> { @Override public void translateNode( @@ -885,7 +892,7 @@ public void translateNode( private static class WindowAssignTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform, PCollection>> { + PTransform, PCollection>> { @Override public void translateNode( @@ -921,7 +928,7 @@ public void translateNode( private static class ReshuffleTranslatorStreaming extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>, PCollection>>> { + PTransform>, PCollection>>> { @Override public void translateNode( @@ -937,7 +944,7 @@ public void translateNode( private static class GroupByKeyTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>, PCollection>>>> { + PTransform>, PCollection>>>> { @Override public void translateNode( @@ -977,30 +984,28 @@ public void translateNode( new CoderTypeInformation<>(outputCoder, context.getPipelineOptions()); WindowDoFnOperator> doFnOperator = - FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( - context, - transform, - inputKvCoder, - outputCoder, - SystemReduceFn.buffering(inputKvCoder.getValueCoder()), - new HashMap<>(), - Collections.emptyList()); + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + SystemReduceFn.buffering(inputKvCoder.getValueCoder()), + new HashMap<>(), + Collections.emptyList()); outDataStream = inputDataStream .keyBy(keySelector) .transform(fullName, outputTypeInfo, doFnOperator) .uid(fullName); - } context.setOutputDataStream(context.getOutput(transform), outDataStream); - } } private static class CombinePerKeyTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>, PCollection>>> { + PTransform>, PCollection>>> { @Override boolean canTranslate( @@ -1032,8 +1037,7 @@ public void translateNode( Coder>> outputCoder = context.getWindowedInputCoder(context.getOutput(transform)); - DataStream>> inputDataStream = - context.getInputDataStream(input); + DataStream>> inputDataStream = context.getInputDataStream(input); SerializablePipelineOptions serializablePipelineOptions = new SerializablePipelineOptions(context.getPipelineOptions()); @@ -1056,7 +1060,8 @@ public void translateNode( if (!context.isStreaming()) { outDataStream = - FlinkStreamingAggregationsTranslators.batchCombinePerKeyNoSideInputs(context, transform, combineFn); + FlinkStreamingAggregationsTranslators.batchCombinePerKeyNoSideInputs( + context, transform, combineFn); } else { WindowDoFnOperator doFnOperator = FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( @@ -1078,9 +1083,10 @@ public void translateNode( transformSideInputs(sideInputs, context); SingleOutputStreamOperator>> outDataStream; - if(!context.isStreaming()) { + if (!context.isStreaming()) { outDataStream = - FlinkStreamingAggregationsTranslators.batchCombinePerKey(context, transform, combineFn, transformSideInputs.f0, sideInputs); + FlinkStreamingAggregationsTranslators.batchCombinePerKey( + context, transform, combineFn, transformSideInputs.f0, sideInputs); } else { WindowDoFnOperator doFnOperator = FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( @@ -1108,7 +1114,7 @@ public void translateNode( private static class GBKIntoKeyedWorkItemsTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform>, PCollection>>> { + PTransform>, PCollection>>> { @Override boolean canTranslate( @@ -1158,7 +1164,7 @@ public void translateNode( private static class ToKeyedWorkItemInGlobalWindow extends RichFlatMapFunction< - WindowedValue>, WindowedValue>> { + WindowedValue>, WindowedValue>> { private final SerializablePipelineOptions options; @@ -1196,7 +1202,7 @@ public void flatMap( private static class FlattenPCollectionTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - PTransform, PCollection>> { + PTransform, PCollection>> { @Override public void translateNode( @@ -1264,16 +1270,14 @@ public void flatMap(T t, Collector collector) throws Exception { } } - /** - * Registers classes specialized to the Flink runner. - */ + /** Registers classes specialized to the Flink runner. */ @AutoService(TransformPayloadTranslatorRegistrar.class) public static class FlinkTransformsRegistrar implements TransformPayloadTranslatorRegistrar { @Override public Map< - ? extends Class, - ? extends PTransformTranslation.TransformPayloadTranslator> - getTransformPayloadTranslators() { + ? extends Class, + ? extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { return ImmutableMap ., PTransformTranslation.TransformPayloadTranslator>builder() .put( @@ -1283,15 +1287,12 @@ public static class FlinkTransformsRegistrar implements TransformPayloadTranslat } } - /** - * A translator just to vend the URN. - */ + /** A translator just to vend the URN. */ private static class CreateStreamingFlinkViewPayloadTranslator extends PTransformTranslation.TransformPayloadTranslator.NotSerializable< - CreateStreamingFlinkView.CreateFlinkPCollectionView> { + CreateStreamingFlinkView.CreateFlinkPCollectionView> { - private CreateStreamingFlinkViewPayloadTranslator() { - } + private CreateStreamingFlinkViewPayloadTranslator() {} @Override public String getUrn() { @@ -1299,9 +1300,7 @@ public String getUrn() { } } - /** - * A translator to support {@link TestStream} with Flink. - */ + /** A translator to support {@link TestStream} with Flink. */ private static class TestStreamTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator> { @@ -1347,12 +1346,12 @@ void translateNode(TestStream testStream, FlinkStreamingTranslationContext co * {@link ValueWithRecordId}. */ static class UnboundedSourceWrapperNoValueWithRecordId< - OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark> + OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark> extends RichParallelSourceFunction> implements ProcessingTimeCallbackCompat, - BeamStoppableFunction, - CheckpointListener, - CheckpointedFunction { + BeamStoppableFunction, + CheckpointListener, + CheckpointedFunction { private final UnboundedSourceWrapper unboundedSourceWrapper; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index fa672a46feda..3687e0e5c4b2 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -474,7 +474,10 @@ public void initializeState(StateInitializationContext context) throws Exception if (keyCoder != null) { keyedStateInternals = new FlinkStateInternals<>( - (KeyedStateBackend) getKeyedStateBackend(), keyCoder, windowingStrategy.getWindowFn().windowCoder(), serializedOptions); + (KeyedStateBackend) getKeyedStateBackend(), + keyCoder, + windowingStrategy.getWindowFn().windowCoder(), + serializedOptions); if (timerService == null) { timerService = @@ -602,7 +605,10 @@ private void earlyBindStateIfNeeded() throws IllegalArgumentException, IllegalAc if (doFn != null) { DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); FlinkStateInternals.EarlyBinder earlyBinder = - new FlinkStateInternals.EarlyBinder(getKeyedStateBackend(), serializedOptions, windowingStrategy.getWindowFn().windowCoder()); + new FlinkStateInternals.EarlyBinder( + getKeyedStateBackend(), + serializedOptions, + windowingStrategy.getWindowFn().windowCoder()); for (DoFnSignature.StateDeclaration value : signature.stateDeclarations().values()) { StateSpec spec = (StateSpec) signature.stateDeclarations().get(value.id()).field().get(doFn); @@ -944,7 +950,7 @@ private void checkInvokeStartBundle() { @SuppressWarnings("NonAtomicVolatileUpdate") @SuppressFBWarnings("VO_VOLATILE_INCREMENT") private void checkInvokeFinishBundleByCount() { - if(!shoudBundleElements()) { + if (!shoudBundleElements()) { return; } // We do not access this statement concurrently, but we want to make sure that each thread @@ -960,7 +966,7 @@ private void checkInvokeFinishBundleByCount() { /** Check whether invoke finishBundle by timeout. */ private void checkInvokeFinishBundleByTime() { - if(!shoudBundleElements()) { + if (!shoudBundleElements()) { return; } long now = getProcessingTimeService().getCurrentProcessingTime(); @@ -1190,6 +1196,7 @@ public static class BufferedOutputManager implements DoFnRunners.Output * buffering. It will not be acquired during flushing the buffer. */ private final Lock bufferLock; + private final boolean isStreaming; private Map> idsToTags; @@ -1397,7 +1404,13 @@ public BufferedOutputManager create( NonKeyedPushedBackElementsHandler.create(listStateBuffer); return new BufferedOutputManager<>( - output, mainTag, tagsToOutputTags, tagsToIds, bufferLock, pushedBackElementsHandler, isStreaming); + output, + mainTag, + tagsToOutputTags, + tagsToIds, + bufferLock, + pushedBackElementsHandler, + isStreaming); } private TaggedKvCoder buildTaggedKvCoder() { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java index 446a4541dd1a..5a7e25299ff7 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java @@ -111,7 +111,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java index c94fb69ef68e..03570143231b 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java @@ -17,9 +17,12 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming; -import java.util.*; -import java.util.stream.Collectors; - +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; import org.apache.beam.runners.flink.translation.functions.AbstractFlinkCombineRunner; import org.apache.beam.runners.flink.translation.functions.HashingFlinkCombineRunner; import org.apache.beam.runners.flink.translation.functions.SortingFlinkCombineRunner; @@ -37,7 +40,6 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; @@ -146,13 +148,12 @@ public void initializeState(StateInitializationContext context) throws Exception ListStateDescriptor>> descriptor = new ListStateDescriptor<>( - "buffered-elements", - new CoderTypeSerializer<>(windowedInputCoder, serializedOptions)); + "buffered-elements", new CoderTypeSerializer<>(windowedInputCoder, serializedOptions)); checkpointedState = context.getOperatorStateStore().getListState(descriptor); - if(context.isRestored() && this.checkpointedState != null) { - for(WindowedValue> wkv : this.checkpointedState.get()) { + if (context.isRestored() && this.checkpointedState != null) { + for (WindowedValue> wkv : this.checkpointedState.get()) { this.state.put(wkv.getValue().getKey(), wkv); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java index 3e5d68df1df7..74eba2491d3d 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java @@ -119,18 +119,12 @@ public Boundedness getBoundedness() { public SplitEnumerator, Map>>> createEnumerator(SplitEnumeratorContext> enumContext) throws Exception { - if(boundedness == Boundedness.BOUNDED) { + if (boundedness == Boundedness.BOUNDED) { return new LazyFlinkSourceSplitEnumerator<>( - enumContext, - beamSource, - serializablePipelineOptions.get(), - numSplits); + enumContext, beamSource, serializablePipelineOptions.get(), numSplits); } else { return new FlinkSourceSplitEnumerator<>( - enumContext, - beamSource, - serializablePipelineOptions.get(), - numSplits); + enumContext, beamSource, serializablePipelineOptions.get(), numSplits); } } @@ -141,7 +135,7 @@ public Boundedness getBoundedness() { Map>> checkpoint) throws Exception { SplitEnumerator, Map>>> enumerator = - createEnumerator(enumContext); + createEnumerator(enumContext); checkpoint.forEach( (subtaskId, splitsForSubtask) -> enumerator.addSplitsBack(splitsForSubtask, subtaskId)); return enumerator; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java index fdd14025a95a..4cb7e99c679d 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java @@ -19,18 +19,11 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Optional; - import javax.annotation.Nullable; - import org.apache.beam.runners.flink.FlinkPipelineOptions; -import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSourceSplit; -import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSourceSplitEnumerator; import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.compat.SplitEnumeratorCompat; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.FileBasedSource; @@ -38,7 +31,6 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.flink.api.connector.source.SplitEnumeratorContext; -import org.apache.flink.api.connector.source.SplitsAssignment; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,23 +84,23 @@ public void start() { @Override public void handleSplitRequest(int subtask, @Nullable String hostname) { if (!context.registeredReaders().containsKey(subtask)) { - // reader failed between sending the request and now. skip this request. - return; + // reader failed between sending the request and now. skip this request. + return; } if (LOG.isInfoEnabled()) { - final String hostInfo = - hostname == null ? "(no host locality info)" : "(on host '" + hostname + "')"; - LOG.info("Subtask {} {} is requesting a file source split", subtask, hostInfo); + final String hostInfo = + hostname == null ? "(no host locality info)" : "(on host '" + hostname + "')"; + LOG.info("Subtask {} {} is requesting a file source split", subtask, hostInfo); } if (!pendingSplits.isEmpty()) { - final FlinkSourceSplit split = pendingSplits.remove(pendingSplits.size() - 1); - context.assignSplit(split, subtask); - LOG.info("Assigned split to subtask {} : {}", subtask, split); + final FlinkSourceSplit split = pendingSplits.remove(pendingSplits.size() - 1); + context.assignSplit(split, subtask); + LOG.info("Assigned split to subtask {} : {}", subtask, split); } else { - context.signalNoMoreSplits(subtask); - LOG.info("No more splits available for subtask {}", subtask); + context.signalNoMoreSplits(subtask); + LOG.info("No more splits available for subtask {}", subtask); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java index d87d84d93dc2..6b23dd13c9b8 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java @@ -101,7 +101,7 @@ protected FlinkBoundedSourceReader( public InputStatus pollNext(ReaderOutput> output) throws Exception { checkExceptionAndMaybeThrow(); - if(currentReader == null && currentSplitId == -1) { + if (currentReader == null && currentSplitId == -1) { context.sendSplitRequest(); } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index 388271cdd68a..2856813ce6ad 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -29,8 +29,6 @@ import java.util.function.Function; import java.util.stream.Stream; import javax.annotation.Nonnull; - -import com.esotericsoftware.kryo.serializers.DefaultSerializers; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaces; @@ -254,17 +252,18 @@ private FlinkStateBinder(StateNamespace namespace, StateContext stateContext) public ValueState bindValue( String id, StateSpec> spec, Coder coder) { FlinkValueState valueState = - new FlinkValueState<>(flinkStateBackend, id, namespace, coder, namespaceKeySerializer, fasterCopy); + new FlinkValueState<>( + flinkStateBackend, id, namespace, coder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( - valueState.flinkStateDescriptor, - valueState.namespace, namespaceKeySerializer); + valueState.flinkStateDescriptor, valueState.namespace, namespaceKeySerializer); return valueState; } @Override public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { FlinkBagState bagState = - new FlinkBagState<>(flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); + new FlinkBagState<>( + flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( bagState.flinkStateDescriptor, bagState.namespace, namespaceKeySerializer); return bagState; @@ -273,7 +272,8 @@ public BagState bindBag(String id, StateSpec> spec, Coder< @Override public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { FlinkSetState setState = - new FlinkSetState<>(flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); + new FlinkSetState<>( + flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( setState.flinkStateDescriptor, setState.namespace, namespaceKeySerializer); return setState; @@ -287,7 +287,13 @@ public MapState bindMap( Coder mapValueCoder) { FlinkMapState mapState = new FlinkMapState<>( - flinkStateBackend, id, namespace, mapKeyCoder, mapValueCoder, namespaceKeySerializer, fasterCopy); + flinkStateBackend, + id, + namespace, + mapKeyCoder, + mapValueCoder, + namespaceKeySerializer, + fasterCopy); collectGlobalWindowStateDescriptor( mapState.flinkStateDescriptor, mapState.namespace, namespaceKeySerializer); return mapState; @@ -297,11 +303,12 @@ public MapState bindMap( public OrderedListState bindOrderedList( String id, StateSpec> spec, Coder elemCoder) { FlinkOrderedListState flinkOrderedListState = - new FlinkOrderedListState<>(flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); + new FlinkOrderedListState<>( + flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( flinkOrderedListState.flinkStateDescriptor, flinkOrderedListState.namespace, - namespaceKeySerializer); + namespaceKeySerializer); return flinkOrderedListState; } @@ -323,11 +330,15 @@ public CombiningState bindCom Combine.CombineFn combineFn) { FlinkCombiningState combiningState = new FlinkCombiningState<>( - flinkStateBackend, id, combineFn, namespace, accumCoder, namespaceKeySerializer, fasterCopy); + flinkStateBackend, + id, + combineFn, + namespace, + accumCoder, + namespaceKeySerializer, + fasterCopy); collectGlobalWindowStateDescriptor( - combiningState.flinkStateDescriptor, - combiningState.namespace, - namespaceKeySerializer); + combiningState.flinkStateDescriptor, combiningState.namespace, namespaceKeySerializer); return combiningState; } @@ -351,7 +362,7 @@ CombiningState bindCombiningWithContext( collectGlobalWindowStateDescriptor( combiningStateWithContext.flinkStateDescriptor, combiningStateWithContext.namespace, - namespaceKeySerializer); + namespaceKeySerializer); return combiningStateWithContext; } @@ -392,7 +403,7 @@ public Coder getCoder() { public FlinkStateNamespaceKeySerializer(Coder coder) { this.coder = coder; } - + @Override public boolean isImmutableType() { return false; @@ -434,7 +445,8 @@ public StateNamespace deserialize(DataInputView source) throws IOException { } @Override - public StateNamespace deserialize(StateNamespace reuse, DataInputView source) throws IOException { + public StateNamespace deserialize(StateNamespace reuse, DataInputView source) + throws IOException { return deserialize(source); } @@ -460,14 +472,12 @@ public TypeSerializerSnapshot snapshotConfiguration() { /** Serializer configuration snapshot for compatibility and format evolution. */ @SuppressWarnings("WeakerAccess") - public final static class FlinkStateNameSpaceSerializerSnapshot implements TypeSerializerSnapshot { - - @Nullable - private Coder windowCoder; + public static final class FlinkStateNameSpaceSerializerSnapshot + implements TypeSerializerSnapshot { - public FlinkStateNameSpaceSerializerSnapshot(){ + @Nullable private Coder windowCoder; - } + public FlinkStateNameSpaceSerializerSnapshot() {} FlinkStateNameSpaceSerializerSnapshot(FlinkStateNamespaceKeySerializer ser) { this.windowCoder = ser.getCoder(); @@ -484,7 +494,8 @@ public void writeSnapshot(DataOutputView out) throws IOException { } @Override - public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) throws IOException { + public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) + throws IOException { this.windowCoder = new JavaSerializer>().deserialize(in); } @@ -494,7 +505,8 @@ public TypeSerializer restoreSerializer() { } @Override - public TypeSerializerSchemaCompatibility resolveSchemaCompatibility(TypeSerializer newSerializer) { + public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( + TypeSerializer newSerializer) { return TypeSerializerSchemaCompatibility.compatibleAsIs(); } } @@ -521,7 +533,6 @@ private static class FlinkValueState implements ValueState { this.flinkStateBackend = flinkStateBackend; this.namespaceSerializer = namespaceSerializer; - flinkStateDescriptor = new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, fasterCopy)); } @@ -530,8 +541,7 @@ private static class FlinkValueState implements ValueState { public void write(T input) { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .update(input); } catch (Exception e) { throw new RuntimeException("Error updating state.", e); @@ -547,8 +557,7 @@ public ValueState readLater() { public T read() { try { return flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value(); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -621,7 +630,7 @@ public void clearRange(Instant minTimestamp, Instant limitTimestamp) { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); partitionedState.update(Lists.newArrayList(sortedMap.values())); } catch (Exception e) { throw new RuntimeException("Error adding to bag state.", e); @@ -638,7 +647,7 @@ public void add(TimestampedValue value) { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); partitionedState.add(value); } catch (Exception e) { throw new RuntimeException("Error adding to bag state.", e); @@ -653,8 +662,7 @@ public Boolean read() { try { Iterable> result = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(); return result == null; } catch (Exception e) { @@ -680,7 +688,7 @@ private SortedMap> readAsMap() { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); listValues = MoreObjects.firstNonNull(partitionedState.get(), Collections.emptyList()); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -702,8 +710,7 @@ public GroupingState, Iterable>> readLat public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -742,7 +749,7 @@ public void add(T input) { try { ListState partitionedState = flinkStateBackend.getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); if (storesVoidValues) { Preconditions.checkState(input == null, "Expected to a null value but was: %s", input); // Flink does not allow storing null values @@ -802,8 +809,7 @@ public Boolean read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(); return result == null; } catch (Exception e) { @@ -822,8 +828,7 @@ public ReadableState readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -928,8 +933,7 @@ public AccumT getAccum() { try { AccumT accum = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value(); return accum != null ? accum : combineFn.createAccumulator(); } catch (Exception e) { @@ -967,8 +971,7 @@ public ReadableState isEmpty() { public Boolean read() { try { return flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value() == null; } catch (Exception e) { @@ -987,8 +990,7 @@ public ReadableState readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1096,8 +1098,7 @@ public AccumT getAccum() { try { AccumT accum = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value(); return accum != null ? accum : combineFn.createAccumulator(context); } catch (Exception e) { @@ -1135,8 +1136,7 @@ public ReadableState isEmpty() { public Boolean read() { try { return flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value() == null; } catch (Exception e) { @@ -1155,8 +1155,7 @@ public ReadableState readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1351,8 +1350,7 @@ public ReadableState get(final KeyT input) { try { ValueT value = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(key); return (value != null) ? value : defaultValue; } catch (Exception e) { @@ -1371,8 +1369,7 @@ public ReadableState get(final KeyT input) { public void put(KeyT key, ValueT value) { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .put(key, value); } catch (Exception e) { throw new RuntimeException("Error put kv to state.", e); @@ -1385,14 +1382,12 @@ public ReadableState computeIfAbsent( try { ValueT current = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(key); if (current == null) { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .put(key, mappingFunction.apply(key)); } return ReadableStates.immediate(current); @@ -1405,8 +1400,7 @@ public ReadableState computeIfAbsent( public void remove(KeyT key) { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .remove(key); } catch (Exception e) { throw new RuntimeException("Error remove map state key.", e); @@ -1421,8 +1415,7 @@ public Iterable read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .keys(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1445,8 +1438,7 @@ public Iterable read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .values(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1469,8 +1461,7 @@ public Iterable> read() { try { Iterable> result = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .entries(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1508,8 +1499,7 @@ public ReadableState>> readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1559,6 +1549,7 @@ private static class FlinkSetState implements SetState { this.flinkStateDescriptor = new MapStateDescriptor<>( stateId, new CoderTypeSerializer<>(coder, fasterCopy), BooleanSerializer.INSTANCE); + this.namespaceSerializer = namespaceSerializer; } @Override @@ -1566,8 +1557,7 @@ public ReadableState contains(final T t) { try { Boolean result = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(t); return ReadableStates.immediate(result != null && result); } catch (Exception e) { @@ -1595,8 +1585,7 @@ public ReadableState addIfAbsent(final T t) { public void remove(T t) { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .remove(t); } catch (Exception e) { throw new RuntimeException("Error remove value to state.", e); @@ -1612,8 +1601,7 @@ public SetState readLater() { public void add(T value) { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .put(value, true); } catch (Exception e) { throw new RuntimeException("Error add value to state.", e); @@ -1628,8 +1616,7 @@ public Boolean read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .keys(); return result == null || Iterables.isEmpty(result); } catch (Exception e) { @@ -1649,8 +1636,7 @@ public Iterable read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .keys(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1662,8 +1648,7 @@ public Iterable read() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace, namespaceSerializer, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1723,7 +1708,9 @@ public static class EarlyBinder implements StateBinder { private final FlinkStateNamespaceKeySerializer namespaceSerializer; public EarlyBinder( - KeyedStateBackend keyedStateBackend, SerializablePipelineOptions pipelineOptions, Coder windowCoder) { + KeyedStateBackend keyedStateBackend, + SerializablePipelineOptions pipelineOptions, + Coder windowCoder) { this.keyedStateBackend = keyedStateBackend; this.fasterCopy = pipelineOptions.get().as(FlinkPipelineOptions.class).getFasterCopy(); this.namespaceSerializer = new FlinkStateNamespaceKeySerializer(windowCoder); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index 4a25e06c6701..2cc0c8c7c13a 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -800,10 +800,10 @@ public void testGCForGlobalWindow() throws Exception { assertThat(testHarness.numKeyedStateEntries(), is(2)); // Cleanup due to end of global window -// testHarness.processWatermark( -// GlobalWindow.INSTANCE.maxTimestamp().plus(Duration.millis(2)).getMillis()); -// assertThat(testHarness.numEventTimeTimers(), is(0)); -// assertThat(testHarness.numKeyedStateEntries(), is(0)); + // testHarness.processWatermark( + // GlobalWindow.INSTANCE.maxTimestamp().plus(Duration.millis(2)).getMillis()); + // assertThat(testHarness.numEventTimeTimers(), is(0)); + // assertThat(testHarness.numKeyedStateEntries(), is(0)); // Any new state will also be cleaned up on close testHarness.processElement( diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java index 22713f6b33c6..a5dc643c5ca0 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java @@ -16,301 +16,293 @@ * limitations under the License. */ package org.apache.beam.runners.flink.translation.wrappers.streaming; -// -// import static java.util.Collections.emptyList; -// import static java.util.Collections.emptyMap; -// import static -// org.apache.beam.runners.flink.translation.wrappers.streaming.StreamRecordStripper.stripStreamRecordFromWindowedValue; -// import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; -// import static org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.ON_TIME; -// import static org.hamcrest.MatcherAssert.assertThat; -// import static org.hamcrest.Matchers.containsInAnyOrder; -// import static org.hamcrest.core.Is.is; -// import static org.joda.time.Duration.standardMinutes; -// import static org.junit.Assert.assertEquals; -// -// import java.io.ByteArrayOutputStream; -// import java.nio.ByteBuffer; -// import org.apache.beam.runners.core.KeyedWorkItem; -// import org.apache.beam.runners.core.SystemReduceFn; -// import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -// import org.apache.beam.runners.flink.FlinkPipelineOptions; -// import -// org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator.MultiOutputOutputManagerFactory; -// import org.apache.beam.sdk.coders.Coder; -// import org.apache.beam.sdk.coders.CoderRegistry; -// import org.apache.beam.sdk.coders.KvCoder; -// import org.apache.beam.sdk.coders.VarLongCoder; -// import org.apache.beam.sdk.transforms.Sum; -// import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -// import org.apache.beam.sdk.transforms.windowing.FixedWindows; -// import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -// import org.apache.beam.sdk.transforms.windowing.PaneInfo; -// import org.apache.beam.sdk.util.AppliedCombineFn; -// import org.apache.beam.sdk.util.WindowedValue; -// import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; -// import org.apache.beam.sdk.values.KV; -// import org.apache.beam.sdk.values.TupleTag; -// import org.apache.beam.sdk.values.WindowingStrategy; -// import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -// import org.apache.flink.api.java.functions.KeySelector; -// import org.apache.flink.api.java.typeutils.GenericTypeInfo; -// import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; -// import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -// import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; -// import org.joda.time.Duration; -// import org.joda.time.Instant; -// import org.junit.Test; -// import org.junit.runner.RunWith; -// import org.junit.runners.JUnit4; -// -/// ** Tests for {@link WindowDoFnOperator}. */ -// @RunWith(JUnit4.class) -// @SuppressWarnings({ -// "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) -// }) -// public class WindowDoFnOperatorTest { -// -// @Test -// public void testRestore() throws Exception { -// // test harness -// KeyedOneInputStreamOperatorTestHarness< -// ByteBuffer, WindowedValue>, WindowedValue>> -// testHarness = createTestHarness(getWindowDoFnOperator()); -// testHarness.open(); -// -// // process elements -// IntervalWindow window = new IntervalWindow(new Instant(0), Duration.millis(10_000)); -// testHarness.processWatermark(0L); -// testHarness.processElement( -// Item.builder().key(1L).timestamp(1L).value(100L).window(window).build().toStreamRecord()); -// testHarness.processElement( -// Item.builder().key(1L).timestamp(2L).value(20L).window(window).build().toStreamRecord()); -// testHarness.processElement( -// Item.builder().key(2L).timestamp(3L).value(77L).window(window).build().toStreamRecord()); -// -// // create snapshot -// OperatorSubtaskState snapshot = testHarness.snapshot(0, 0); -// testHarness.close(); -// -// // restore from the snapshot -// testHarness = createTestHarness(getWindowDoFnOperator()); -// testHarness.initializeState(snapshot); -// testHarness.open(); -// -// // close window -// testHarness.processWatermark(10_000L); -// -// Iterable>> output = -// stripStreamRecordFromWindowedValue(testHarness.getOutput()); -// -// assertEquals(2, Iterables.size(output)); -// assertThat( -// output, -// containsInAnyOrder( -// WindowedValue.of( -// KV.of(1L, 120L), -// new Instant(9_999), -// window, -// PaneInfo.createPane(true, true, ON_TIME)), -// WindowedValue.of( -// KV.of(2L, 77L), -// new Instant(9_999), -// window, -// PaneInfo.createPane(true, true, ON_TIME)))); -// // cleanup -// testHarness.close(); -// } -// -// @Test -// public void testTimerCleanupOfPendingTimerList() throws Exception { -// // test harness -// WindowDoFnOperator windowDoFnOperator = getWindowDoFnOperator(); -// KeyedOneInputStreamOperatorTestHarness< -// ByteBuffer, WindowedValue>, WindowedValue>> -// testHarness = createTestHarness(windowDoFnOperator); -// testHarness.open(); -// -// DoFnOperator, KeyedWorkItem, KV>.FlinkTimerInternals -// timerInternals = -// windowDoFnOperator.timerInternals; -// -// // process elements -// IntervalWindow window = new IntervalWindow(new Instant(0), Duration.millis(100)); -// IntervalWindow window2 = new IntervalWindow(new Instant(100), Duration.millis(100)); -// testHarness.processWatermark(0L); -// -// // Use two different keys to check for correct watermark hold calculation -// testHarness.processElement( -// Item.builder().key(1L).timestamp(1L).value(100L).window(window).build().toStreamRecord()); -// testHarness.processElement( -// Item.builder() -// .key(2L) -// .timestamp(150L) -// .value(150L) -// .window(window2) -// .build() -// .toStreamRecord()); -// -// testHarness.processWatermark(1); -// -// // Note that the following is 1 because the state is key-partitioned -// assertThat(Iterables.size(timerInternals.pendingTimersById.keys()), is(1)); -// -// assertThat(testHarness.numKeyedStateEntries(), is(6)); -// // close bundle -// testHarness.setProcessingTime( -// testHarness.getProcessingTime() -// + 2 * FlinkPipelineOptions.defaults().getMaxBundleTimeMills()); -// assertThat(windowDoFnOperator.getCurrentOutputWatermark(), is(1L)); -// -// // close window -// testHarness.processWatermark(100L); -// -// // Note that the following is zero because we only the first key is active -// assertThat(Iterables.size(timerInternals.pendingTimersById.keys()), is(0)); -// -// assertThat(testHarness.numKeyedStateEntries(), is(3)); -// -// // close bundle -// testHarness.setProcessingTime( -// testHarness.getProcessingTime() -// + 2 * FlinkPipelineOptions.defaults().getMaxBundleTimeMills()); -// assertThat(windowDoFnOperator.getCurrentOutputWatermark(), is(100L)); -// -// testHarness.processWatermark(200L); -// -// // All the state has been cleaned up -// assertThat(testHarness.numKeyedStateEntries(), is(0)); -// -// assertThat( -// stripStreamRecordFromWindowedValue(testHarness.getOutput()), -// containsInAnyOrder( -// WindowedValue.of( -// KV.of(1L, 100L), new Instant(99), window, PaneInfo.createPane(true, true, -// ON_TIME)), -// WindowedValue.of( -// KV.of(2L, 150L), -// new Instant(199), -// window2, -// PaneInfo.createPane(true, true, ON_TIME)))); -// -// // cleanup -// testHarness.close(); -// } -// -// private WindowDoFnOperator getWindowDoFnOperator() { -// WindowingStrategy windowingStrategy = -// WindowingStrategy.of(FixedWindows.of(standardMinutes(1))); -// -// TupleTag> outputTag = new TupleTag<>("main-output"); -// -// SystemReduceFn reduceFn = -// SystemReduceFn.combining( -// VarLongCoder.of(), -// AppliedCombineFn.withInputCoder( -// Sum.ofLongs(), -// CoderRegistry.createDefault(), -// KvCoder.of(VarLongCoder.of(), VarLongCoder.of()))); -// -// Coder windowCoder = windowingStrategy.getWindowFn().windowCoder(); -// SingletonKeyedWorkItemCoder workItemCoder = -// SingletonKeyedWorkItemCoder.of(VarLongCoder.of(), VarLongCoder.of(), windowCoder); -// FullWindowedValueCoder> inputCoder = -// WindowedValue.getFullCoder(workItemCoder, windowCoder); -// FullWindowedValueCoder> outputCoder = -// WindowedValue.getFullCoder(KvCoder.of(VarLongCoder.of(), VarLongCoder.of()), windowCoder); -// -// return new WindowDoFnOperator( -// reduceFn, -// "stepName", -// (Coder) inputCoder, -// outputTag, -// emptyList(), -// new MultiOutputOutputManagerFactory<>( -// outputTag, -// outputCoder, -// new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), -// windowingStrategy, -// emptyMap(), -// emptyList(), -// FlinkPipelineOptions.defaults(), -// VarLongCoder.of(), -// new WorkItemKeySelector( -// VarLongCoder.of(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults()))); -// } -// -// private KeyedOneInputStreamOperatorTestHarness< -// ByteBuffer, WindowedValue>, WindowedValue>> -// createTestHarness(WindowDoFnOperator windowDoFnOperator) throws Exception -// { -// return new KeyedOneInputStreamOperatorTestHarness<>( -// windowDoFnOperator, -// (KeySelector>, ByteBuffer>) -// o -> { -// try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { -// VarLongCoder.of().encode(o.getValue().key(), baos); -// return ByteBuffer.wrap(baos.toByteArray()); -// } -// }, -// new GenericTypeInfo<>(ByteBuffer.class)); -// } -// -// private static class Item { -// -// static ItemBuilder builder() { -// return new ItemBuilder(); -// } -// -// private long key; -// private long value; -// private long timestamp; -// private IntervalWindow window; -// -// StreamRecord>> toStreamRecord() { -// WindowedValue item = WindowedValue.of(value, new Instant(timestamp), window, -// NO_FIRING); -// WindowedValue> keyedItem = -// WindowedValue.of( -// new SingletonKeyedWorkItem<>(key, item), new Instant(timestamp), window, NO_FIRING); -// return new StreamRecord<>(keyedItem); -// } -// -// private static final class ItemBuilder { -// -// private long key; -// private long value; -// private long timestamp; -// private IntervalWindow window; -// -// ItemBuilder key(long key) { -// this.key = key; -// return this; -// } -// -// ItemBuilder value(long value) { -// this.value = value; -// return this; -// } -// -// ItemBuilder timestamp(long timestamp) { -// this.timestamp = timestamp; -// return this; -// } -// -// ItemBuilder window(IntervalWindow window) { -// this.window = window; -// return this; -// } -// -// Item build() { -// Item item = new Item(); -// item.key = this.key; -// item.value = this.value; -// item.window = this.window; -// item.timestamp = this.timestamp; -// return item; -// } -// } -// } -// } + +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static org.apache.beam.runners.flink.translation.wrappers.streaming.StreamRecordStripper.stripStreamRecordFromWindowedValue; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.ON_TIME; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.core.Is.is; +import static org.joda.time.Duration.standardMinutes; +import static org.junit.Assert.assertEquals; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator.MultiOutputOutputManagerFactory; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.AppliedCombineFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link WindowDoFnOperator}. */ +@RunWith(JUnit4.class) +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class WindowDoFnOperatorTest { + + @Test + public void testRestore() throws Exception { + // test harness + KeyedOneInputStreamOperatorTestHarness< + ByteBuffer, WindowedValue>, WindowedValue>> + testHarness = createTestHarness(getWindowDoFnOperator()); + testHarness.open(); + + // process elements + IntervalWindow window = new IntervalWindow(new Instant(0), Duration.millis(10_000)); + testHarness.processWatermark(0L); + testHarness.processElement( + Item.builder().key(1L).timestamp(1L).value(100L).window(window).build().toStreamRecord()); + testHarness.processElement( + Item.builder().key(1L).timestamp(2L).value(20L).window(window).build().toStreamRecord()); + testHarness.processElement( + Item.builder().key(2L).timestamp(3L).value(77L).window(window).build().toStreamRecord()); + + // create snapshot + OperatorSubtaskState snapshot = testHarness.snapshot(0, 0); + testHarness.close(); + + // restore from the snapshot + testHarness = createTestHarness(getWindowDoFnOperator()); + testHarness.initializeState(snapshot); + testHarness.open(); + + // close window + testHarness.processWatermark(10_000L); + + Iterable>> output = + stripStreamRecordFromWindowedValue(testHarness.getOutput()); + + assertEquals(2, Iterables.size(output)); + assertThat( + output, + containsInAnyOrder( + WindowedValue.of( + KV.of(1L, 120L), + new Instant(9_999), + window, + PaneInfo.createPane(true, true, ON_TIME)), + WindowedValue.of( + KV.of(2L, 77L), + new Instant(9_999), + window, + PaneInfo.createPane(true, true, ON_TIME)))); + // cleanup + testHarness.close(); + } + + @Test + public void testTimerCleanupOfPendingTimerList() throws Exception { + // test harness + WindowDoFnOperator windowDoFnOperator = getWindowDoFnOperator(); + KeyedOneInputStreamOperatorTestHarness< + ByteBuffer, WindowedValue>, WindowedValue>> + testHarness = createTestHarness(windowDoFnOperator); + testHarness.open(); + + DoFnOperator, KeyedWorkItem, KV>.FlinkTimerInternals + timerInternals = windowDoFnOperator.timerInternals; + + // process elements + IntervalWindow window = new IntervalWindow(new Instant(0), Duration.millis(100)); + IntervalWindow window2 = new IntervalWindow(new Instant(100), Duration.millis(100)); + testHarness.processWatermark(0L); + + // Use two different keys to check for correct watermark hold calculation + testHarness.processElement( + Item.builder().key(1L).timestamp(1L).value(100L).window(window).build().toStreamRecord()); + testHarness.processElement( + Item.builder() + .key(2L) + .timestamp(150L) + .value(150L) + .window(window2) + .build() + .toStreamRecord()); + + testHarness.processWatermark(1); + + // Note that the following is 1 because the state is key-partitioned + assertThat(Iterables.size(timerInternals.pendingTimersById.keys()), is(1)); + + assertThat(testHarness.numKeyedStateEntries(), is(6)); + // close bundle + testHarness.setProcessingTime( + testHarness.getProcessingTime() + + 2 * FlinkPipelineOptions.defaults().getMaxBundleTimeMills()); + assertThat(windowDoFnOperator.getCurrentOutputWatermark(), is(1L)); + + // close window + testHarness.processWatermark(100L); + + // Note that the following is zero because we only the first key is active + assertThat(Iterables.size(timerInternals.pendingTimersById.keys()), is(0)); + + assertThat(testHarness.numKeyedStateEntries(), is(3)); + + // close bundle + testHarness.setProcessingTime( + testHarness.getProcessingTime() + + 2 * FlinkPipelineOptions.defaults().getMaxBundleTimeMills()); + assertThat(windowDoFnOperator.getCurrentOutputWatermark(), is(100L)); + + testHarness.processWatermark(200L); + + // All the state has been cleaned up + assertThat(testHarness.numKeyedStateEntries(), is(0)); + + assertThat( + stripStreamRecordFromWindowedValue(testHarness.getOutput()), + containsInAnyOrder( + WindowedValue.of( + KV.of(1L, 100L), new Instant(99), window, PaneInfo.createPane(true, true, ON_TIME)), + WindowedValue.of( + KV.of(2L, 150L), + new Instant(199), + window2, + PaneInfo.createPane(true, true, ON_TIME)))); + + // cleanup + testHarness.close(); + } + + private WindowDoFnOperator getWindowDoFnOperator() { + WindowingStrategy windowingStrategy = + WindowingStrategy.of(FixedWindows.of(standardMinutes(1))); + + TupleTag> outputTag = new TupleTag<>("main-output"); + + SystemReduceFn reduceFn = + SystemReduceFn.combining( + VarLongCoder.of(), + AppliedCombineFn.withInputCoder( + Sum.ofLongs(), + CoderRegistry.createDefault(), + KvCoder.of(VarLongCoder.of(), VarLongCoder.of()))); + + Coder windowCoder = windowingStrategy.getWindowFn().windowCoder(); + SingletonKeyedWorkItemCoder workItemCoder = + SingletonKeyedWorkItemCoder.of(VarLongCoder.of(), VarLongCoder.of(), windowCoder); + FullWindowedValueCoder> inputCoder = + WindowedValue.getFullCoder(workItemCoder, windowCoder); + FullWindowedValueCoder> outputCoder = + WindowedValue.getFullCoder(KvCoder.of(VarLongCoder.of(), VarLongCoder.of()), windowCoder); + + return new WindowDoFnOperator( + reduceFn, + "stepName", + (Coder) inputCoder, + outputTag, + emptyList(), + new MultiOutputOutputManagerFactory<>( + outputTag, + outputCoder, + new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), + windowingStrategy, + emptyMap(), + emptyList(), + FlinkPipelineOptions.defaults(), + VarLongCoder.of(), + new WorkItemKeySelector( + VarLongCoder.of(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults()))); + } + + private KeyedOneInputStreamOperatorTestHarness< + ByteBuffer, WindowedValue>, WindowedValue>> + createTestHarness(WindowDoFnOperator windowDoFnOperator) throws Exception { + return new KeyedOneInputStreamOperatorTestHarness<>( + windowDoFnOperator, + (KeySelector>, ByteBuffer>) + o -> { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + VarLongCoder.of().encode(o.getValue().getKey(), baos); + return ByteBuffer.wrap(baos.toByteArray()); + } + }, + new GenericTypeInfo<>(ByteBuffer.class)); + } + + private static class Item { + + static ItemBuilder builder() { + return new ItemBuilder(); + } + + private long key; + private long value; + private long timestamp; + private IntervalWindow window; + + StreamRecord>> toStreamRecord() { + WindowedValue> keyedItem = + WindowedValue.of(KV.of(key, value), new Instant(timestamp), window, NO_FIRING); + return new StreamRecord<>(keyedItem); + } + + private static final class ItemBuilder { + + private long key; + private long value; + private long timestamp; + private IntervalWindow window; + + ItemBuilder key(long key) { + this.key = key; + return this; + } + + ItemBuilder value(long value) { + this.value = value; + return this; + } + + ItemBuilder timestamp(long timestamp) { + this.timestamp = timestamp; + return this; + } + + ItemBuilder window(IntervalWindow window) { + this.window = window; + return this; + } + + Item build() { + Item item = new Item(); + item.key = this.key; + item.value = this.value; + item.window = this.window; + item.timestamp = this.timestamp; + return item; + } + } + } +} diff --git a/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html b/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html index 939c64ed9c49..d572851acfbd 100644 --- a/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html +++ b/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html @@ -107,6 +107,11 @@ Address of the Flink Master where the Pipeline should be executed. Can either be of the form "host:port" or one of the special values [local], [collection] or [auto]. Default: [auto] + + forceSlotSharingGroup + Set a slot sharing group for all bounded sources. This is required when using Datastream to have the same scheduling behaviour as the Dataset API. + Default: true + jobCheckIntervalInSecs Set job check interval in seconds under detached mode in method waitUntilFinish, by default it is 5 seconds diff --git a/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html b/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html index eb5c525d78b7..37a7c4489ccc 100644 --- a/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html +++ b/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html @@ -107,6 +107,11 @@ Address of the Flink Master where the Pipeline should be executed. Can either be of the form "host:port" or one of the special values [local], [collection] or [auto]. Default: [auto] + + force_slot_sharing_group + Set a slot sharing group for all bounded sources. This is required when using Datastream to have the same scheduling behaviour as the Dataset API. + Default: true + job_check_interval_in_secs Set job check interval in seconds under detached mode in method waitUntilFinish, by default it is 5 seconds From aeb8937730e3314958483e72cf3a321a66f3dd8b Mon Sep 17 00:00:00 2001 From: jto Date: Thu, 29 Aug 2024 14:50:58 +0200 Subject: [PATCH 17/19] [Flink] fix WindowDoFnOperatorTest --- .../streaming/WindowDoFnOperatorTest.java | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java index a5dc643c5ca0..408e8d05a4a0 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java @@ -74,7 +74,7 @@ public void testRestore() throws Exception { // test harness KeyedOneInputStreamOperatorTestHarness< ByteBuffer, WindowedValue>, WindowedValue>> - testHarness = createTestHarness(getWindowDoFnOperator()); + testHarness = createTestHarness(getWindowDoFnOperator(true)); testHarness.open(); // process elements @@ -92,7 +92,7 @@ public void testRestore() throws Exception { testHarness.close(); // restore from the snapshot - testHarness = createTestHarness(getWindowDoFnOperator()); + testHarness = createTestHarness(getWindowDoFnOperator(true)); testHarness.initializeState(snapshot); testHarness.open(); @@ -123,7 +123,7 @@ public void testRestore() throws Exception { @Test public void testTimerCleanupOfPendingTimerList() throws Exception { // test harness - WindowDoFnOperator windowDoFnOperator = getWindowDoFnOperator(); + WindowDoFnOperator windowDoFnOperator = getWindowDoFnOperator(true); KeyedOneInputStreamOperatorTestHarness< ByteBuffer, WindowedValue>, WindowedValue>> testHarness = createTestHarness(windowDoFnOperator); @@ -195,7 +195,7 @@ public void testTimerCleanupOfPendingTimerList() throws Exception { testHarness.close(); } - private WindowDoFnOperator getWindowDoFnOperator() { + private WindowDoFnOperator getWindowDoFnOperator(boolean streaming) { WindowingStrategy windowingStrategy = WindowingStrategy.of(FixedWindows.of(standardMinutes(1))); @@ -217,6 +217,9 @@ private WindowDoFnOperator getWindowDoFnOperator() { FullWindowedValueCoder> outputCoder = WindowedValue.getFullCoder(KvCoder.of(VarLongCoder.of(), VarLongCoder.of()), windowCoder); + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + options.setStreaming(streaming); + return new WindowDoFnOperator( reduceFn, "stepName", @@ -224,16 +227,13 @@ private WindowDoFnOperator getWindowDoFnOperator() { outputTag, emptyList(), new MultiOutputOutputManagerFactory<>( - outputTag, - outputCoder, - new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), + outputTag, outputCoder, new SerializablePipelineOptions(options)), windowingStrategy, emptyMap(), emptyList(), - FlinkPipelineOptions.defaults(), + options, VarLongCoder.of(), - new WorkItemKeySelector( - VarLongCoder.of(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults()))); + new WorkItemKeySelector(VarLongCoder.of(), new SerializablePipelineOptions(options))); } private KeyedOneInputStreamOperatorTestHarness< From 5af3269e9a8d3e4f2be6e388b93c171e7fdccf2b Mon Sep 17 00:00:00 2001 From: jto Date: Fri, 30 Aug 2024 11:54:46 +0200 Subject: [PATCH 18/19] [Flink] spotless --- .../runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java | 3 ++- .../flink/translation/types/CoderTypeSerializer.java | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java index cc657413f6f1..3e42bb54494e 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java @@ -126,7 +126,8 @@ public void processElement(ProcessContext c) throws Exception { new ReduceFnRunner<>( key, windowingStrategy, - ExecutableTriggerStateMachine.create(TriggerStateMachines.stateMachineForTrigger(triggerProto)), + ExecutableTriggerStateMachine.create( + TriggerStateMachines.stateMachineForTrigger(triggerProto)), stateInternals, timerInternals, outputWindowedValue(), diff --git a/runners/flink/1.14/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.14/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java index 6c21ea8edc00..decee51128a4 100644 --- a/runners/flink/1.14/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ b/runners/flink/1.14/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java @@ -53,7 +53,12 @@ public class CoderTypeSerializer extends TypeSerializer { private final boolean fasterCopy; public CoderTypeSerializer(Coder coder, SerializablePipelineOptions pipelineOptions) { - this(coder, Preconditions.checkNotNull(pipelineOptions).get().as(FlinkPipelineOptions.class).getFasterCopy()); + this( + coder, + Preconditions.checkNotNull(pipelineOptions) + .get() + .as(FlinkPipelineOptions.class) + .getFasterCopy()); } public CoderTypeSerializer(Coder coder, boolean fasterCopy) { From b672504e8669349c6527bbda87cff0cd60d6e763 Mon Sep 17 00:00:00 2001 From: jto Date: Tue, 10 Sep 2024 11:02:03 +0200 Subject: [PATCH 19/19] [Flink] fix broken tests --- .../runners/flink/FlinkSubmissionTest.java | 3 ++- .../wrappers/streaming/DoFnOperatorTest.java | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java index 22a9ce4f39ab..d411a38b22dc 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java @@ -135,6 +135,7 @@ public void testDetachedSubmissionStreaming() throws Exception { private void runSubmission(boolean isDetached, boolean isStreaming) throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); + options.as(FlinkPipelineOptions.class).setStreaming(isStreaming); options.setTempLocation(TEMP_FOLDER.getRoot().getPath()); String jarPath = Iterables.getFirst( @@ -171,7 +172,7 @@ private void waitUntilJobIsCompleted() throws Exception { .allMatch(jobStatus -> jobStatus.getJobState().name().equals("FINISHED"))) { return; } - Thread.sleep(50); + Thread.sleep(100); } } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index 2cc0c8c7c13a..67e21a17bc6b 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -800,10 +800,10 @@ public void testGCForGlobalWindow() throws Exception { assertThat(testHarness.numKeyedStateEntries(), is(2)); // Cleanup due to end of global window - // testHarness.processWatermark( - // GlobalWindow.INSTANCE.maxTimestamp().plus(Duration.millis(2)).getMillis()); - // assertThat(testHarness.numEventTimeTimers(), is(0)); - // assertThat(testHarness.numKeyedStateEntries(), is(0)); + testHarness.processWatermark( + GlobalWindow.INSTANCE.maxTimestamp().plus(Duration.millis(2)).getMillis()); + assertThat(testHarness.numEventTimeTimers(), is(0)); + assertThat(testHarness.numKeyedStateEntries(), is(0)); // Any new state will also be cleaned up on close testHarness.processElement( @@ -866,6 +866,9 @@ public void onTimer(OnTimerContext context, @StateId(stateId) ValueState KeySelector>, ByteBuffer> keySelector = e -> FlinkKeyUtils.encodeKey(e.getValue().getKey(), StringUtf8Coder.of()); + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + options.setStreaming(true); + DoFnOperator, KV, KV> doFnOperator = new DoFnOperator<>( fn, @@ -875,11 +878,11 @@ public void onTimer(OnTimerContext context, @StateId(stateId) ValueState outputTag, Collections.emptyList(), new DoFnOperator.MultiOutputOutputManagerFactory<>( - outputTag, coder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), + outputTag, coder, new SerializablePipelineOptions(options)), windowingStrategy, new HashMap<>(), /* side-input mapping */ Collections.emptyList(), /* side inputs */ - FlinkPipelineOptions.defaults(), + options, StringUtf8Coder.of(), /* key coder */ keySelector, DoFnSchemaInformation.create(), @@ -888,8 +891,7 @@ outputTag, coder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults( return new KeyedOneInputStreamOperatorTestHarness<>( doFnOperator, keySelector, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), options)); } @Test