diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index b4ce64fd700b..34250a5019c8 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -388,11 +388,7 @@ public void initializeState(StateInitializationContext context) throws Exception outputManager = outputManagerFactory.create( - output, - getLockToAcquireForStateAccessDuringBundles(), - getOperatorStateBackend(), - getKeyedStateBackend(), - keySelector); + output, getLockToAcquireForStateAccessDuringBundles(), getOperatorStateBackend()); } /** @@ -828,9 +824,7 @@ interface OutputManagerFactory extends Serializable { BufferedOutputManager create( Output>> output, Lock bufferLock, - @Nullable OperatorStateBackend operatorStateBackend, - @Nullable KeyedStateBackend keyedStateBackend, - @Nullable KeySelector keySelector) + OperatorStateBackend operatorStateBackend) throws Exception; } @@ -1018,35 +1012,19 @@ public MultiOutputOutputManagerFactory( public BufferedOutputManager create( Output>> output, Lock bufferLock, - OperatorStateBackend operatorStateBackend, - @Nullable KeyedStateBackend keyedStateBackend, - @Nullable KeySelector keySelector) + OperatorStateBackend operatorStateBackend) throws Exception { Preconditions.checkNotNull(output); Preconditions.checkNotNull(bufferLock); Preconditions.checkNotNull(operatorStateBackend); - Preconditions.checkState( - (keyedStateBackend == null) == (keySelector == null), - "Either both KeyedStatebackend and Keyselector are provided or none."); TaggedKvCoder taggedKvCoder = buildTaggedKvCoder(); ListStateDescriptor>> taggedOutputPushbackStateDescriptor = new ListStateDescriptor<>("bundle-buffer-tag", new CoderTypeSerializer<>(taggedKvCoder)); - - final PushedBackElementsHandler>> pushedBackElementsHandler; - if (keyedStateBackend != null) { - // build a key selector for the tagged output - KeySelector>, ?> taggedValueKeySelector = - (KeySelector>, Object>) - value -> keySelector.getKey(value.getValue()); - pushedBackElementsHandler = - KeyedPushedBackElementsHandler.create( - taggedValueKeySelector, keyedStateBackend, taggedOutputPushbackStateDescriptor); - } else { - ListState>> listState = - operatorStateBackend.getListState(taggedOutputPushbackStateDescriptor); - pushedBackElementsHandler = NonKeyedPushedBackElementsHandler.create(listState); - } + ListState>> listStateBuffer = + operatorStateBackend.getListState(taggedOutputPushbackStateDescriptor); + PushedBackElementsHandler>> pushedBackElementsHandler = + NonKeyedPushedBackElementsHandler.create(listStateBuffer); return new BufferedOutputManager<>( output, mainTag, tagsToOutputTags, tagsToIds, bufferLock, pushedBackElementsHandler); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index 6e96b3b712fd..c5eca1ee61f8 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -1313,20 +1313,25 @@ public void testBundleKeyed() throws Exception { options.setMaxBundleSize(2L); options.setMaxBundleTimeMills(10L); - IdentityDoFn> doFn = - new IdentityDoFn>() { + DoFn, String> doFn = + new DoFn, String>() { + @ProcessElement + public void processElement(ProcessContext ctx) { + // Change output type of element to test that we do not depend on the input keying + ctx.output(ctx.element().getValue()); + } + @FinishBundle public void finishBundle(FinishBundleContext context) { context.output( - KV.of("key2", "finishBundle"), - BoundedWindow.TIMESTAMP_MIN_VALUE, - GlobalWindow.INSTANCE); + "finishBundle", BoundedWindow.TIMESTAMP_MIN_VALUE, GlobalWindow.INSTANCE); } }; - DoFnOperator.MultiOutputOutputManagerFactory> outputManagerFactory = + DoFnOperator.MultiOutputOutputManagerFactory outputManagerFactory = new DoFnOperator.MultiOutputOutputManagerFactory( - outputTag, WindowedValue.getFullCoder(kvCoder, GlobalWindow.Coder.INSTANCE)); + outputTag, + WindowedValue.getFullCoder(kvCoder.getValueCoder(), GlobalWindow.Coder.INSTANCE)); DoFnOperator, KV> doFnOperator = new DoFnOperator( @@ -1347,8 +1352,7 @@ public void finishBundle(FinishBundleContext context) { DoFnSchemaInformation.create(), Collections.emptyMap()); - OneInputStreamOperatorTestHarness< - WindowedValue>, WindowedValue>> + OneInputStreamOperatorTestHarness>, WindowedValue> testHarness = new KeyedOneInputStreamOperatorTestHarness( doFnOperator, keySelector, keySelector.getProducedType()); @@ -1365,10 +1369,10 @@ public void finishBundle(FinishBundleContext context) { assertThat( stripStreamRecordFromWindowedValue(testHarness.getOutput()), contains( - WindowedValue.valueInGlobalWindow(KV.of("key", "a")), - WindowedValue.valueInGlobalWindow(KV.of("key", "b")), - WindowedValue.valueInGlobalWindow(KV.of("key2", "finishBundle")), - WindowedValue.valueInGlobalWindow(KV.of("key", "c")))); + WindowedValue.valueInGlobalWindow("a"), + WindowedValue.valueInGlobalWindow("b"), + WindowedValue.valueInGlobalWindow("finishBundle"), + WindowedValue.valueInGlobalWindow("c"))); // Take a snapshot OperatorSubtaskState snapshot = testHarness.snapshot(0, 0); @@ -1376,12 +1380,11 @@ public void finishBundle(FinishBundleContext context) { // Finish bundle element will be buffered as part of finishing a bundle in snapshot() PushedBackElementsHandler>> pushedBackElementsHandler = doFnOperator.outputManager.pushedBackElementsHandler; - assertThat(pushedBackElementsHandler, instanceOf(KeyedPushedBackElementsHandler.class)); + assertThat(pushedBackElementsHandler, instanceOf(NonKeyedPushedBackElementsHandler.class)); List>> bufferedElements = pushedBackElementsHandler.getElements().collect(Collectors.toList()); assertThat( - bufferedElements, - contains(KV.of(0, WindowedValue.valueInGlobalWindow(KV.of("key2", "finishBundle"))))); + bufferedElements, contains(KV.of(0, WindowedValue.valueInGlobalWindow("finishBundle")))); testHarness.close(); @@ -1424,9 +1427,9 @@ public void finishBundle(FinishBundleContext context) { stripStreamRecordFromWindowedValue(testHarness.getOutput()), contains( // The first finishBundle is restored from the checkpoint - WindowedValue.valueInGlobalWindow(KV.of("key2", "finishBundle")), - WindowedValue.valueInGlobalWindow(KV.of("key", "d")), - WindowedValue.valueInGlobalWindow(KV.of("key2", "finishBundle")))); + WindowedValue.valueInGlobalWindow("finishBundle"), + WindowedValue.valueInGlobalWindow("d"), + WindowedValue.valueInGlobalWindow("finishBundle"))); testHarness.close(); }