diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MapClassIntegrationIT.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MapClassIntegrationIT.java new file mode 100644 index 000000000000..05563185b548 --- /dev/null +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MapClassIntegrationIT.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.examples.cookbook; + +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.joda.time.Duration; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MapClassIntegrationIT { + + static class MapDoFn extends DoFn, Void> { + @StateId("mapState") + private final StateSpec> mapStateSpec = StateSpecs.map(); + + @ProcessElement + public void processElement( + @Element KV element, @StateId("mapState") MapState mapState) { + mapState.put(Long.toString(element.getValue() % 100), element.getValue()); + if (element.getValue() % 1000 == 0) { + Iterable> entries = mapState.entries().read(); + if (entries != null) { + System.err.println("ENTRIES " + Iterables.toString(entries)); + } else { + System.err.println("ENTRIES IS NULL"); + } + } + } + } + + @Test + public void testDataflowMapState() { + PipelineOptions options = TestPipeline.testingPipelineOptions(); + Pipeline p = Pipeline.create(options); + p.apply( + "GenerateSequence", + GenerateSequence.from(0).withRate(1000, Duration.standardSeconds(1))) + .apply("WithKeys", WithKeys.of("key")) + .apply("MapState", ParDo.of(new MapDoFn())); + p.run(); + } +} diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java index d69c4807fe62..cd1f9824ac3b 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java @@ -27,6 +27,7 @@ import java.util.NavigableMap; import java.util.Objects; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.runners.core.StateTag.StateBinder; @@ -55,6 +56,9 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.joda.time.Instant; /** @@ -641,7 +645,23 @@ public void clear() { @Override public ReadableState get(K key) { - return ReadableStates.immediate(contents.get(key)); + return getOrDefault(key, null); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState getOrDefault( + K key, @Nullable V defaultValue) { + return new ReadableState() { + @Override + public @org.checkerframework.checker.nullness.qual.Nullable V read() { + return contents.getOrDefault(key, defaultValue); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + return this; + } + }; } @Override @@ -650,10 +670,11 @@ public void put(K key, V value) { } @Override - public ReadableState putIfAbsent(K key, V value) { + public ReadableState computeIfAbsent( + K key, Function mappingFunction) { V v = contents.get(key); if (v == null) { - v = contents.put(key, value); + v = contents.put(key, mappingFunction.apply(key)); } return ReadableStates.immediate(v); @@ -701,6 +722,23 @@ public ReadableState>> entries() { return CollectionViewState.of(contents.entrySet()); } + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Boolean> + isEmpty() { + return new ReadableState() { + @Override + public @org.checkerframework.checker.nullness.qual.Nullable Boolean read() { + return contents.isEmpty(); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + return this; + } + }; + } + @Override public boolean isCleared() { return contents.isEmpty(); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java index e185b0a0079e..acee02b84834 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java @@ -239,6 +239,12 @@ public static StateTag> convertToBagT StateSpecs.convertToBagSpecInternal(combiningTag.getSpec())); } + public static StateTag> convertToMapTagInternal( + StateTag> setTag) { + return new SimpleStateTag<>( + new StructuredId(setTag.getId()), StateSpecs.convertToMapSpecInternal(setTag.getSpec())); + } + private static class StructuredId implements Serializable { private final StateKind kind; private final String rawId; diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java index b11083c152a3..4c793d34d99c 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java @@ -632,7 +632,7 @@ public void testMapReadable() throws Exception { // test get ReadableState get = value.get("B"); value.put("B", 2); - assertNull(get.read()); + assertThat(get.read(), equalTo(2)); // test addIfAbsent value.putIfAbsent("C", 3); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index aa9c614eda03..610a4568d52e 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Function; import java.util.stream.Stream; import javax.annotation.Nonnull; import org.apache.beam.runners.core.StateInternals; @@ -72,7 +73,10 @@ import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.joda.time.Instant; /** @@ -1033,15 +1037,32 @@ private static class FlinkMapState implements MapState get(final KeyT input) { - try { - return ReadableStates.immediate( - flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) - .get(input)); - } catch (Exception e) { - throw new RuntimeException("Error get from state.", e); - } + return getOrDefault(input, null); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState getOrDefault( + KeyT key, @Nullable ValueT defaultValue) { + return new ReadableState() { + @Override + public @Nullable ValueT read() { + try { + ValueT value = + flinkStateBackend + .getPartitionedState( + namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .get(key); + return (value != null) ? value : defaultValue; + } catch (Exception e) { + throw new RuntimeException("Error get from state.", e); + } + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + return this; + } + }; } @Override @@ -1057,7 +1078,8 @@ public void put(KeyT key, ValueT value) { } @Override - public ReadableState putIfAbsent(final KeyT key, final ValueT value) { + public ReadableState computeIfAbsent( + final KeyT key, Function mappingFunction) { try { ValueT current = flinkStateBackend @@ -1069,7 +1091,7 @@ public ReadableState putIfAbsent(final KeyT key, final ValueT value) { flinkStateBackend .getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) - .put(key, value); + .put(key, mappingFunction.apply(key)); } return ReadableStates.immediate(current); } catch (Exception e) { @@ -1161,6 +1183,25 @@ public ReadableState>> readLater() { }; } + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Boolean> + isEmpty() { + ReadableState> keys = this.keys(); + return new ReadableState() { + @Override + public @Nullable Boolean read() { + return Iterables.isEmpty(keys.read()); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + keys.readLater(); + return this; + } + }; + } + @Override public void clear() { try { diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index 197eff83aa43..ebea7ed982c0 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -153,8 +153,6 @@ def commonLegacyExcludeCategories = [ 'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms', 'org.apache.beam.sdk.testing.UsesDistributionMetrics', 'org.apache.beam.sdk.testing.UsesGaugeMetrics', - 'org.apache.beam.sdk.testing.UsesSetState', - 'org.apache.beam.sdk.testing.UsesMapState', 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs', 'org.apache.beam.sdk.testing.UsesTestStream', 'org.apache.beam.sdk.testing.UsesParDoLifecycle', diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java index e93a328deb7e..6281c9cb2ffd 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java @@ -176,7 +176,13 @@ ParDo.SingleOutput, OutputT> getOriginalParDo() { public PCollection expand(PCollection> input) { DoFn, OutputT> fn = originalParDo.getFn(); verifyFnIsStateful(fn); - DataflowRunner.verifyDoFnSupportedBatch(fn); + DataflowPipelineOptions options = + input.getPipeline().getOptions().as(DataflowPipelineOptions.class); + DataflowRunner.verifyDoFnSupported( + fn, + false, + DataflowRunner.useUnifiedWorker(options), + DataflowRunner.useStreamingEngine(options)); DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy()); if (isFnApi) { @@ -209,7 +215,13 @@ static class StatefulMultiOutputParDo public PCollectionTuple expand(PCollection> input) { DoFn, OutputT> fn = originalParDo.getFn(); verifyFnIsStateful(fn); - DataflowRunner.verifyDoFnSupportedBatch(fn); + DataflowPipelineOptions options = + input.getPipeline().getOptions().as(DataflowPipelineOptions.class); + DataflowRunner.verifyDoFnSupported( + fn, + false, + DataflowRunner.useUnifiedWorker(options), + DataflowRunner.useStreamingEngine(options)); DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy()); if (isFnApi) { diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 3a75f40cca18..bf301eb916d0 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -1252,7 +1252,12 @@ private static void translateFn( boolean isStateful = DoFnSignatures.isStateful(fn); if (isStateful) { - DataflowRunner.verifyDoFnSupported(fn, context.getPipelineOptions().isStreaming()); + DataflowPipelineOptions options = context.getPipelineOptions(); + DataflowRunner.verifyDoFnSupported( + fn, + options.isStreaming(), + DataflowRunner.useUnifiedWorker(options), + DataflowRunner.useStreamingEngine(options)); DataflowRunner.verifyStateSupportForWindowingStrategy(windowingStrategy); } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index cf3555af169d..5bdc7ae55b3f 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -2227,35 +2227,47 @@ static boolean useUnifiedWorker(DataflowPipelineOptions options) { return hasExperiment(options, "use_runner_v2") || hasExperiment(options, "use_unified_worker"); } - static void verifyDoFnSupportedBatch(DoFn fn) { - verifyDoFnSupported(fn, false); + static boolean useStreamingEngine(DataflowPipelineOptions options) { + return hasExperiment(options, GcpOptions.STREAMING_ENGINE_EXPERIMENT) + || hasExperiment(options, GcpOptions.WINDMILL_SERVICE_EXPERIMENT); } - static void verifyDoFnSupportedStreaming(DoFn fn) { - verifyDoFnSupported(fn, true); - } - - static void verifyDoFnSupported(DoFn fn, boolean streaming) { - if (DoFnSignatures.usesSetState(fn)) { - // https://issues.apache.org/jira/browse/BEAM-1479 - throw new UnsupportedOperationException( - String.format( - "%s does not currently support %s", - DataflowRunner.class.getSimpleName(), SetState.class.getSimpleName())); - } - if (DoFnSignatures.usesMapState(fn)) { - // https://issues.apache.org/jira/browse/BEAM-1474 - throw new UnsupportedOperationException( - String.format( - "%s does not currently support %s", - DataflowRunner.class.getSimpleName(), MapState.class.getSimpleName())); - } + static void verifyDoFnSupported( + DoFn fn, boolean streaming, boolean workerV2, boolean streamingEngine) { if (streaming && DoFnSignatures.requiresTimeSortedInput(fn)) { throw new UnsupportedOperationException( String.format( "%s does not currently support @RequiresTimeSortedInput in streaming mode.", DataflowRunner.class.getSimpleName())); } + if (DoFnSignatures.usesSetState(fn)) { + if (workerV2) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support %s when using runner V2", + DataflowRunner.class.getSimpleName(), SetState.class.getSimpleName())); + } + if (streaming && streamingEngine) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support %s when using streaming engine", + DataflowRunner.class.getSimpleName(), SetState.class.getSimpleName())); + } + } + if (DoFnSignatures.usesMapState(fn)) { + if (workerV2) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support %s when using runner V2", + DataflowRunner.class.getSimpleName(), MapState.class.getSimpleName())); + } + if (streaming && streamingEngine) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support %s when using streaming engine", + DataflowRunner.class.getSimpleName(), MapState.class.getSimpleName())); + } + } } static void verifyStateSupportForWindowingStrategy(WindowingStrategy strategy) { diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index 282fc9580690..16fec840bd41 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -1417,16 +1417,19 @@ public void process() {} } @Test - public void testMapStateUnsupportedInBatch() throws Exception { + public void testMapStateUnsupportedRunnerV2() throws Exception { PipelineOptions options = buildPipelineOptions(); - options.as(StreamingOptions.class).setStreaming(false); + ExperimentalOptions.addExperiment(options.as(ExperimentalOptions.class), "use_runner_v2"); verifyMapStateUnsupported(options); } @Test - public void testMapStateUnsupportedInStreaming() throws Exception { + public void testMapStateUnsupportedStreamingEngine() throws Exception { PipelineOptions options = buildPipelineOptions(); - options.as(StreamingOptions.class).setStreaming(true); + ExperimentalOptions.addExperiment( + options.as(ExperimentalOptions.class), GcpOptions.STREAMING_ENGINE_EXPERIMENT); + options.as(DataflowPipelineOptions.class).setStreaming(true); + verifyMapStateUnsupported(options); } @@ -1449,17 +1452,19 @@ public void process() {} } @Test - public void testSetStateUnsupportedInBatch() throws Exception { + public void testSetStateUnsupportedRunnerV2() throws Exception { PipelineOptions options = buildPipelineOptions(); - options.as(StreamingOptions.class).setStreaming(false); + ExperimentalOptions.addExperiment(options.as(ExperimentalOptions.class), "use_runner_v2"); Pipeline.create(options); verifySetStateUnsupported(options); } @Test - public void testSetStateUnsupportedInStreaming() throws Exception { + public void testSetStateUnsupportedStreamingEngine() throws Exception { PipelineOptions options = buildPipelineOptions(); - options.as(StreamingOptions.class).setStreaming(true); + ExperimentalOptions.addExperiment( + options.as(ExperimentalOptions.class), GcpOptions.STREAMING_ENGINE_EXPERIMENT); + options.as(DataflowPipelineOptions.class).setStreaming(true); verifySetStateUnsupported(options); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java index 95e190a87781..185e5f6d3b6e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java @@ -24,6 +24,7 @@ import java.io.OutputStream; import java.io.OutputStreamWriter; import java.nio.charset.StandardCharsets; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -31,11 +32,14 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Random; +import java.util.Set; import java.util.SortedSet; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.function.BiConsumer; +import java.util.function.Function; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; @@ -51,6 +55,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListInsertRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListUpdateRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; +import org.apache.beam.sdk.coders.BooleanCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; import org.apache.beam.sdk.coders.CoderException; @@ -94,7 +99,10 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.TreeRangeSet; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Futures; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.joda.time.Duration; import org.joda.time.Instant; @@ -166,17 +174,23 @@ public BagState bindBag(StateTag> address, Coder elemCoder @Override public SetState bindSet(StateTag> spec, Coder elemCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", SetState.class.getSimpleName())); + WindmillSet result = + new WindmillSet(namespace, spec, stateFamily, elemCoder, cache, isNewKey); + result.initializeForWorkItem(reader, scopedReadStateSupplier); + return result; } @Override public MapState bindMap( - StateTag> spec, - Coder mapKeyCoder, - Coder mapValueCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", SetState.class.getSimpleName())); + StateTag> spec, Coder keyCoder, Coder valueCoder) { + WindmillMap result = (WindmillMap) cache.get(namespace, spec); + if (result == null) { + result = + new WindmillMap( + namespace, spec, stateFamily, keyCoder, valueCoder, isNewKey); + } + result.initializeForWorkItem(reader, scopedReadStateSupplier); + return result; } @Override @@ -1075,6 +1089,488 @@ private Future>> getFuture( } } + static class WindmillSet extends SimpleWindmillState implements SetState { + WindmillMap windmillMap; + + WindmillSet( + StateNamespace namespace, + StateTag> address, + String stateFamily, + Coder keyCoder, + WindmillStateCache.ForKey cache, + boolean isNewKey) { + StateTag> internalMapAddress = + StateTags.convertToMapTagInternal(address); + WindmillMap cachedMap = + (WindmillMap) cache.get(namespace, internalMapAddress); + this.windmillMap = + (cachedMap != null) + ? cachedMap + : new WindmillMap<>( + namespace, + internalMapAddress, + stateFamily, + keyCoder, + BooleanCoder.of(), + isNewKey); + } + + @Override + protected WorkItemCommitRequest persistDirectly(ForKey cache) throws IOException { + return windmillMap.persistDirectly(cache); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Boolean> + contains(K k) { + return windmillMap.getOrDefault(k, false); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Boolean> + addIfAbsent(K k) { + return new ReadableState() { + ReadableState putState = windmillMap.putIfAbsent(k, true); + + @Override + public @Nullable Boolean read() { + Boolean result = putState.read(); + return (result != null) ? result : false; + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + putState = putState.readLater(); + return this; + } + }; + } + + @Override + public void remove(K k) { + windmillMap.remove(k); + } + + @Override + public void add(K value) { + windmillMap.put(value, true); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Boolean> + isEmpty() { + return windmillMap.isEmpty(); + } + + @Override + public @Nullable Iterable read() { + return windmillMap.keys().read(); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized SetState readLater() { + windmillMap.keys().readLater(); + return this; + } + + @Override + public void clear() { + windmillMap.clear(); + } + + @Override + void initializeForWorkItem( + WindmillStateReader reader, Supplier scopedReadStateSupplier) { + windmillMap.initializeForWorkItem(reader, scopedReadStateSupplier); + } + + @Override + void cleanupAfterWorkItem() { + windmillMap.cleanupAfterWorkItem(); + } + } + + static class WindmillMap extends SimpleWindmillState implements MapState { + private final StateNamespace namespace; + private final StateTag> address; + private final ByteString stateKeyPrefix; + private final String stateFamily; + private final Coder keyCoder; + private final Coder valueCoder; + private boolean complete; + + // TODO(reuvenlax): Should we evict items from the cache? We would have to make sure + // that anything in the cache that is not committed is not evicted. negativeCache could be + // evicted whenever we want. + private Map cachedValues = Maps.newHashMap(); + private Set negativeCache = Sets.newHashSet(); + private boolean cleared = false; + + private Set localAdditions = Sets.newHashSet(); + private Set localRemovals = Sets.newHashSet(); + + WindmillMap( + StateNamespace namespace, + StateTag> address, + String stateFamily, + Coder keyCoder, + Coder valueCoder, + boolean isNewKey) { + this.namespace = namespace; + this.address = address; + this.stateKeyPrefix = encodeKey(namespace, address); + this.stateFamily = stateFamily; + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + this.complete = isNewKey; + } + + private K userKeyFromProtoKey(ByteString tag) throws IOException { + Preconditions.checkState(tag.startsWith(stateKeyPrefix)); + ByteString keyBytes = tag.substring(stateKeyPrefix.size()); + return keyCoder.decode(keyBytes.newInput(), Context.OUTER); + } + + private ByteString protoKeyFromUserKey(K key) throws IOException { + ByteString.Output keyStream = ByteString.newOutput(); + stateKeyPrefix.writeTo(keyStream); + keyCoder.encode(key, keyStream, Context.OUTER); + return keyStream.toByteString(); + } + + @Override + protected WorkItemCommitRequest persistDirectly(ForKey cache) throws IOException { + if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) { + // No changes, so return directly. + return WorkItemCommitRequest.newBuilder().buildPartial(); + } + + WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder(); + + if (cleared) { + commitBuilder + .addTagValuePrefixDeletesBuilder() + .setStateFamily(stateFamily) + .setTagPrefix(stateKeyPrefix); + } + cleared = false; + + for (K key : localAdditions) { + ByteString keyBytes = protoKeyFromUserKey(key); + ByteString.Output valueStream = ByteString.newOutput(); + valueCoder.encode(cachedValues.get(key), valueStream, Context.OUTER); + ByteString valueBytes = valueStream.toByteString(); + + commitBuilder + .addValueUpdatesBuilder() + .setTag(keyBytes) + .setStateFamily(stateFamily) + .getValueBuilder() + .setData(valueBytes) + .setTimestamp(Long.MAX_VALUE); + } + localAdditions.clear(); + + for (K key : localRemovals) { + ByteString.Output keyStream = ByteString.newOutput(); + stateKeyPrefix.writeTo(keyStream); + keyCoder.encode(key, keyStream, Context.OUTER); + ByteString keyBytes = keyStream.toByteString(); + // Leaving data blank means that we delete the tag. + commitBuilder + .addValueUpdatesBuilder() + .setTag(keyBytes) + .setStateFamily(stateFamily) + .getValueBuilder() + .setTimestamp(Long.MAX_VALUE); + + V cachedValue = cachedValues.remove(key); + if (cachedValue != null) { + ByteString.Output valueStream = ByteString.newOutput(); + valueCoder.encode(cachedValues.get(key), valueStream, Context.OUTER); + } + } + negativeCache.addAll(localRemovals); + localRemovals.clear(); + + // TODO(reuvenlax): We should store in the cache parameter, as that would enable caching the + // map + // between work items, reducing fetches to Windmill. To do so, we need keep track of the + // encoded size + // of the map, and to do so efficiently (i.e. without iterating over the entire map on every + // persist) + // we need to track the sizes of each map entry. + cache.put(namespace, address, this, 1); + return commitBuilder.buildPartial(); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState get(K key) { + return getOrDefault(key, null); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState getOrDefault( + K key, @Nullable V defaultValue) { + return new ReadableState() { + @Override + public @Nullable V read() { + Future persistedData = getFutureForKey(key); + try (Closeable scope = scopedReadState()) { + if (localRemovals.contains(key) || negativeCache.contains(key)) { + return null; + } + @Nullable V cachedValue = cachedValues.get(key); + if (cachedValue != null || complete) { + return cachedValue; + } + + V persistedValue = persistedData.get(); + if (persistedValue == null) { + negativeCache.add(key); + return defaultValue; + } + // TODO: Don't do this if it was already in cache. + cachedValues.put(key, persistedValue); + return persistedValue; + } catch (InterruptedException | ExecutionException | IOException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("Unable to read state", e); + } + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + WindmillMap.this.getFutureForKey(key); + return this; + } + }; + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Iterable> + keys() { + ReadableState>> entries = entries(); + return new ReadableState>() { + @Override + public @Nullable Iterable read() { + return Iterables.transform(entries.read(), e -> e.getKey()); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState> readLater() { + entries.readLater(); + return this; + } + }; + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Iterable> + values() { + ReadableState>> entries = entries(); + return new ReadableState>() { + @Override + public @Nullable Iterable read() { + return Iterables.transform(entries.read(), e -> e.getValue()); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState> readLater() { + entries.readLater(); + return this; + } + }; + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Iterable< + @UnknownKeyFor @NonNull @Initialized Entry>> + entries() { + return new ReadableState>>() { + @Override + public Iterable> read() { + if (complete) { + return Iterables.unmodifiableIterable(cachedValues.entrySet()); + } + Future>> persistedData = getFuture(); + try (Closeable scope = scopedReadState()) { + Iterable> data = persistedData.get(); + Iterable> transformedData = + Iterables., Map.Entry>transform( + data, + entry -> { + try { + return new AbstractMap.SimpleEntry<>( + userKeyFromProtoKey(entry.getKey()), entry.getValue()); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + if (data instanceof Weighted) { + // This is a known amount of data. Cache it all. + transformedData.forEach( + e -> { + // The cached data overrides what is read from state, so call putIfAbsent. + cachedValues.putIfAbsent(e.getKey(), e.getValue()); + }); + complete = true; + return Iterables.unmodifiableIterable(cachedValues.entrySet()); + } else { + // This means that the result might be too large to cache, so don't add it to the + // local cache. Instead merge the iterables, giving priority to any local additions + // (represented in cachedValued and localRemovals) that may not have been committed + // yet. + return Iterables.unmodifiableIterable( + Iterables.concat( + cachedValues.entrySet(), + Iterables.filter( + transformedData, + e -> + !cachedValues.containsKey(e.getKey()) + && !localRemovals.contains(e.getKey())))); + } + + } catch (InterruptedException | ExecutionException | IOException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("Unable to read state", e); + } + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public @UnknownKeyFor @NonNull @Initialized ReadableState>> + readLater() { + WindmillMap.this.getFuture(); + return this; + } + }; + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + // TODO(reuvenlax): Can we find a more efficient way of implementing isEmpty than reading + // the entire map? + ReadableState> keys = WindmillMap.this.keys(); + + @Override + public @Nullable Boolean read() { + return Iterables.isEmpty(keys.read()); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + keys.readLater(); + return this; + } + }; + } + + @Override + public void put(K key, V value) { + V oldValue = cachedValues.put(key, value); + if (valueCoder.consistentWithEquals() && value.equals(oldValue)) { + return; + } + localAdditions.add(key); + localRemovals.remove(key); + negativeCache.remove(key); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState computeIfAbsent( + K key, Function mappingFunction) { + return new ReadableState() { + @Override + public @Nullable V read() { + Future persistedData = getFutureForKey(key); + try (Closeable scope = scopedReadState()) { + if (localRemovals.contains(key) || negativeCache.contains(key)) { + return null; + } + @Nullable V cachedValue = cachedValues.get(key); + if (cachedValue != null || complete) { + return cachedValue; + } + + V persistedValue = persistedData.get(); + if (persistedValue == null) { + // This is a new value. Add it to the map and return null. + put(key, mappingFunction.apply(key)); + return null; + } + // TODO: Don't do this if it was already in cache. + cachedValues.put(key, persistedValue); + return persistedValue; + } catch (InterruptedException | ExecutionException | IOException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("Unable to read state", e); + } + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + WindmillMap.this.getFutureForKey(key); + return this; + } + }; + } + + @Override + public void remove(K key) { + if (localRemovals.add(key)) { + cachedValues.remove(key); + localAdditions.remove(key); + } + } + + @Override + public void clear() { + cachedValues.clear(); + localAdditions.clear(); + localRemovals.clear(); + negativeCache.clear(); + cleared = true; + complete = true; + } + + private Future getFutureForKey(K key) { + try { + ByteString.Output keyStream = ByteString.newOutput(); + stateKeyPrefix.writeTo(keyStream); + keyCoder.encode(key, keyStream, Context.OUTER); + return reader.valueFuture(keyStream.toByteString(), stateFamily, valueCoder); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private Future>> getFuture() { + if (complete) { + // The caller will merge in local cached values. + return Futures.immediateFuture(Collections.emptyList()); + } else { + return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder); + } + } + }; + private static class WindmillBag extends SimpleWindmillState implements BagState { private final StateNamespace namespace; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java index 3ef2d7feea0e..62237417208f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java @@ -23,12 +23,14 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.IOException; import java.io.InputStream; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; @@ -38,13 +40,16 @@ import java.util.concurrent.TimeoutException; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.WindmillStateReader.StateTag.Kind; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListEntry; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListRange; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagBag; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListFetchRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagValue; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagValuePrefixRequest; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.Coder.Context; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.Weighted; import org.apache.beam.sdk.values.TimestampedValue; @@ -86,6 +91,12 @@ class WindmillStateReader { */ public static final long MAX_ORDERED_LIST_BYTES = 8L << 20; // 8MB + /** + * Ideal maximum bytes in a tag-value prefix response. However, Windmill will always return at + * least one value if possible irrespective of this limit. + */ + public static final long MAX_TAG_VALUE_PREFIX_BYTES = 8L << 20; // 8MB + /** * Ideal maximum bytes in a KeyedGetDataResponse. However, Windmill will always return at least * one value if possible irrespective of this limit. @@ -102,7 +113,8 @@ enum Kind { VALUE, BAG, WATERMARK, - ORDERED_LIST + ORDERED_LIST, + VALUE_PREFIX } abstract Kind getKind(); @@ -112,9 +124,9 @@ enum Kind { abstract String getStateFamily(); /** - * For {@link Kind#BAG, Kind#ORDERED_LIST} kinds: A previous 'continuation_position' returned by - * Windmill to signal the resulting bag was incomplete. Sending that position will request the - * next page of values. Null for first request. + * For {@link Kind#BAG, Kind#ORDERED_LIST, Kind#VALUE_PREFIX} kinds: A previous + * 'continuation_position' returned by Windmill to signal the resulting bag was incomplete. + * Sending that position will request the next page of values. Null for first request. * *

Null for other kinds. */ @@ -197,11 +209,11 @@ public WindmillStateReader( this.workToken = workToken; } - private static final class CoderAndFuture { - private Coder coder; + private static final class CoderAndFuture { + private Coder coder = null; private final SettableFuture future; - private CoderAndFuture(Coder coder, SettableFuture future) { + private CoderAndFuture(Coder coder, SettableFuture future) { this.coder = coder; this.future = future; } @@ -217,11 +229,14 @@ private SettableFuture getNonDoneFuture(StateTag stateTag) { return future; } - private Coder getAndClearCoder() { + private Coder getAndClearCoder() { if (coder == null) { throw new IllegalStateException("Coder has already been cleared from cache"); } - Coder result = coder; + Coder result = (Coder) coder; + if (result == null) { + throw new IllegalStateException("Coder has already been cleared from cache"); + } coder = null; return result; } @@ -236,13 +251,11 @@ private void checkNoCoder() { @VisibleForTesting ConcurrentLinkedQueue> pendingLookups = new ConcurrentLinkedQueue<>(); - private ConcurrentHashMap, CoderAndFuture> waiting = new ConcurrentHashMap<>(); + private ConcurrentHashMap, CoderAndFuture> waiting = new ConcurrentHashMap<>(); - private Future stateFuture( - StateTag stateTag, @Nullable Coder coder) { - CoderAndFuture coderAndFuture = - new CoderAndFuture<>(coder, SettableFuture.create()); - CoderAndFuture existingCoderAndFutureWildcard = + private Future stateFuture(StateTag stateTag, @Nullable Coder coder) { + CoderAndFuture coderAndFuture = new CoderAndFuture<>(coder, SettableFuture.create()); + CoderAndFuture existingCoderAndFutureWildcard = waiting.putIfAbsent(stateTag, coderAndFuture); if (existingCoderAndFutureWildcard == null) { // Schedule a new request. It's response is guaranteed to find the future and coder. @@ -250,17 +263,16 @@ private Future stateFuture( } else { // Piggy-back on the pending or already answered request. @SuppressWarnings("unchecked") - CoderAndFuture existingCoderAndFuture = - (CoderAndFuture) existingCoderAndFutureWildcard; + CoderAndFuture existingCoderAndFuture = + (CoderAndFuture) existingCoderAndFutureWildcard; coderAndFuture = existingCoderAndFuture; } return wrappedFuture(coderAndFuture.getFuture()); } - private CoderAndFuture getWaiting( - StateTag stateTag, boolean shouldRemove) { - CoderAndFuture coderAndFutureWildcard; + private CoderAndFuture getWaiting(StateTag stateTag, boolean shouldRemove) { + CoderAndFuture coderAndFutureWildcard; if (shouldRemove) { coderAndFutureWildcard = waiting.remove(stateTag); } else { @@ -270,8 +282,7 @@ private CoderAndFuture getWaiting( throw new IllegalStateException("Missing future for " + stateTag); } @SuppressWarnings("unchecked") - CoderAndFuture coderAndFuture = - (CoderAndFuture) coderAndFutureWildcard; + CoderAndFuture coderAndFuture = (CoderAndFuture) coderAndFutureWildcard; return coderAndFuture; } @@ -303,18 +314,27 @@ public Future>> orderedListFuture( valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder))); } + public Future>> valuePrefixFuture( + ByteString prefix, String stateFamily, Coder valueCoder) { + // First request has no continuation position. + StateTag stateTag = + StateTag.of(Kind.VALUE_PREFIX, prefix, stateFamily).toBuilder().build(); + return Preconditions.checkNotNull( + valuesToPagingIterableFuture(stateTag, valueCoder, this.stateFuture(stateTag, valueCoder))); + } + /** * Internal request to fetch the next 'page' of values. Return null if no continuation position is * in {@code contStateTag}, which signals there are no more pages. */ - private @Nullable + private @Nullable Future> continuationFuture( - StateTag contStateTag, Coder elemCoder) { + StateTag contStateTag, Coder coder) { if (contStateTag.getRequestPosition() == null) { // We're done. return null; } - return stateFuture(contStateTag, elemCoder); + return stateFuture(contStateTag, coder); } /** @@ -367,7 +387,7 @@ private Future wrappedFuture(final Future future) { } /** Function to extract an {@link Iterable} from the continuation-supporting page read future. */ - private static class ToIterableFunction + private static class ToIterableFunction implements Function, Iterable> { /** * Reader to request continuation pages from, or {@literal null} if no continuation pages @@ -376,13 +396,13 @@ private static class ToIterableFunction private @Nullable WindmillStateReader reader; private final StateTag stateTag; - private final Coder elemCoder; + private final Coder coder; public ToIterableFunction( - WindmillStateReader reader, StateTag stateTag, Coder elemCoder) { + WindmillStateReader reader, StateTag stateTag, Coder coder) { this.reader = reader; this.stateTag = stateTag; - this.elemCoder = elemCoder; + this.coder = coder; } @SuppressFBWarnings( @@ -407,8 +427,8 @@ public Iterable apply( contStateTag = contStateTag.toBuilder().setSortedListRange(stateTag.getSortedListRange()).build(); } - return new PagingIterable( - reader, valuesAndContPosition.values, contStateTag, elemCoder); + return new PagingIterable( + reader, valuesAndContPosition.values, contStateTag, coder); } } } @@ -417,12 +437,12 @@ public Iterable apply( * Return future which transforms a {@code ValuesAndContPosition} result into the initial * Iterable result expected from the external caller. */ - private Future> valuesToPagingIterableFuture( + private Future> valuesToPagingIterableFuture( final StateTag stateTag, - final Coder elemCoder, + final Coder coder, final Future> future) { Function, Iterable> toIterable = - new ToIterableFunction<>(this, stateTag, elemCoder); + new ToIterableFunction<>(this, stateTag, coder); return Futures.lazyTransform(future, toIterable); } @@ -500,6 +520,18 @@ private Windmill.KeyedGetDataRequest createRequest(Iterable> toFetch .setStateFamily(stateTag.getStateFamily()); break; + case VALUE_PREFIX: + TagValuePrefixRequest.Builder prefixFetchBuilder = + keyedDataBuilder + .addTagValuePrefixesToFetchBuilder() + .setTagPrefix(stateTag.getTag()) + .setStateFamily(stateTag.getStateFamily()) + .setFetchMaxBytes(MAX_TAG_VALUE_PREFIX_BYTES); + if (stateTag.getRequestPosition() != null) { + prefixFetchBuilder.setRequestPosition((ByteString) stateTag.getRequestPosition()); + } + break; + default: throw new RuntimeException("Unknown kind of tag requested: " + stateTag.getKind()); } @@ -583,6 +615,19 @@ private void consumeResponse( } consumeTagValue(value, stateTag); } + for (Windmill.TagValuePrefixResponse prefix_response : response.getTagValuePrefixesList()) { + StateTag stateTag = + StateTag.of( + Kind.VALUE_PREFIX, + prefix_response.getTagPrefix(), + prefix_response.getStateFamily(), + prefix_response.hasRequestPosition() ? prefix_response.getRequestPosition() : null); + if (!toFetch.remove(stateTag)) { + throw new IllegalStateException( + "Received response for unrequested tag " + stateTag + ". Pending tags: " + toFetch); + } + consumeTagPrefixResponse(prefix_response, stateTag); + } for (Windmill.TagSortedListFetchResponse sorted_list : response.getTagSortedListsList()) { SortedListRange sortedListRange = Iterables.getOnlyElement(sorted_list.getFetchRangesList()); Range range = Range.closedOpen(sortedListRange.getStart(), sortedListRange.getLimit()); @@ -680,6 +725,28 @@ private List> sortedListPageValues( return entryList; } + private List> tagPrefixPageTagValues( + Windmill.TagValuePrefixResponse tagValuePrefixResponse, Coder valueCoder) { + if (tagValuePrefixResponse.getTagValuesCount() == 0) { + return new WeightedList<>(Collections.emptyList()); + } + + WeightedList> entryList = + new WeightedList>( + new ArrayList<>(tagValuePrefixResponse.getTagValuesCount())); + for (TagValue entry : tagValuePrefixResponse.getTagValuesList()) { + try { + V value = valueCoder.decode(entry.getValue().getData().newInput(), Context.OUTER); + entryList.addWeighted( + new AbstractMap.SimpleEntry<>(entry.getTag(), value), + entry.getTag().size() + entry.getValue().getData().size()); + } catch (IOException e) { + throw new IllegalStateException("Unable to decode tag value " + e); + } + } + return entryList; + } + private void consumeBag(TagBag bag, StateTag stateTag) { boolean shouldRemove; if (stateTag.getRequestPosition() == null) { @@ -693,11 +760,11 @@ private void consumeBag(TagBag bag, StateTag stateTag) { // continuation positions. shouldRemove = true; } - CoderAndFuture> coderAndFuture = + CoderAndFuture> coderAndFuture = getWaiting(stateTag, shouldRemove); SettableFuture> future = coderAndFuture.getNonDoneFuture(stateTag); - Coder coder = coderAndFuture.getAndClearCoder(); + Coder coder = coderAndFuture.getAndClearCoder(); List values = this.bagPageValues(bag, coder); future.set( new ValuesAndContPosition<>( @@ -705,7 +772,7 @@ private void consumeBag(TagBag bag, StateTag stateTag) { } private void consumeWatermark(Windmill.WatermarkHold watermarkHold, StateTag stateTag) { - CoderAndFuture coderAndFuture = getWaiting(stateTag, false); + CoderAndFuture coderAndFuture = getWaiting(stateTag, false); SettableFuture future = coderAndFuture.getNonDoneFuture(stateTag); // No coders for watermarks coderAndFuture.checkNoCoder(); @@ -725,7 +792,7 @@ private void consumeWatermark(Windmill.WatermarkHold watermarkHold, StateTag void consumeTagValue(TagValue tagValue, StateTag stateTag) { - CoderAndFuture coderAndFuture = getWaiting(stateTag, false); + CoderAndFuture coderAndFuture = getWaiting(stateTag, false); SettableFuture future = coderAndFuture.getNonDoneFuture(stateTag); Coder coder = coderAndFuture.getAndClearCoder(); @@ -744,6 +811,36 @@ private void consumeTagValue(TagValue tagValue, StateTag stateTag) { } } + private void consumeTagPrefixResponse( + Windmill.TagValuePrefixResponse tagValuePrefixResponse, StateTag stateTag) { + boolean shouldRemove; + if (stateTag.getRequestPosition() == null) { + // This is the response for the first page.// Leave the future in the cache so subsequent + // requests for the first page + // can return immediately. + shouldRemove = false; + } else { + // This is a response for a subsequent page. + // Don't cache the future since we may need to make multiple requests with different + // continuation positions. + shouldRemove = true; + } + + CoderAndFuture, ByteString>> coderAndFuture = + getWaiting(stateTag, shouldRemove); + SettableFuture, ByteString>> future = + coderAndFuture.getNonDoneFuture(stateTag); + Coder valueCoder = coderAndFuture.getAndClearCoder(); + List> values = + this.tagPrefixPageTagValues(tagValuePrefixResponse, valueCoder); + future.set( + new ValuesAndContPosition<>( + values, + tagValuePrefixResponse.hasContinuationPosition() + ? tagValuePrefixResponse.getContinuationPosition() + : null)); + } + private void consumeSortedList( Windmill.TagSortedListFetchResponse sortedListFetchResponse, StateTag stateTag) { boolean shouldRemove; @@ -759,7 +856,7 @@ private void consumeSortedList( shouldRemove = true; } - CoderAndFuture, ByteString>> coderAndFuture = + CoderAndFuture, ByteString>> coderAndFuture = getWaiting(stateTag, shouldRemove); SettableFuture, ByteString>> future = coderAndFuture.getNonDoneFuture(stateTag); @@ -772,7 +869,6 @@ private void consumeSortedList( ? sortedListFetchResponse.getContinuationPosition() : null)); } - /** * An iterable over elements backed by paginated GetData requests to Windmill. The iterable may be * iterated over an arbitrary number of times and multiple iterators may be active simultaneously. @@ -789,7 +885,7 @@ private void consumeSortedList( * call to iterator. * */ - private static class PagingIterable implements Iterable { + private static class PagingIterable implements Iterable { /** * The reader we will use for scheduling continuation pages. * @@ -804,17 +900,17 @@ private static class PagingIterable implements It private final StateTag secondPagePos; /** Coder for elements. */ - private final Coder elemCoder; + private final Coder coder; private PagingIterable( WindmillStateReader reader, List firstPage, StateTag secondPagePos, - Coder elemCoder) { + Coder coder) { this.reader = reader; this.firstPage = firstPage; this.secondPagePos = secondPagePos; - this.elemCoder = elemCoder; + this.coder = coder; } @Override @@ -824,7 +920,7 @@ public Iterator iterator() { private StateTag nextPagePos = secondPagePos; private Future> pendingNextPage = // NOTE: The results of continuation page reads are never cached. - reader.continuationFuture(nextPagePos, elemCoder); + reader.continuationFuture(nextPagePos, coder); @Override protected ResultT computeNext() { @@ -854,7 +950,7 @@ protected ResultT computeNext() { valuesAndContPosition.continuationPosition); pendingNextPage = // NOTE: The results of continuation page reads are never cached. - reader.continuationFuture(nextPagePos, elemCoder); + reader.continuationFuture(nextPagePos, coder); } } }; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java index 206cba9cd1ac..e093537899fd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java @@ -22,6 +22,8 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.never; @@ -30,11 +32,16 @@ import com.google.common.collect.Iterables; import java.io.Closeable; +import java.io.IOException; +import java.util.AbstractMap; +import java.util.AbstractMap.SimpleEntry; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import javax.annotation.Nullable; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaceForTest; import org.apache.beam.runners.core.StateTag; @@ -47,12 +54,14 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListUpdateRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagValue; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.Coder.Context; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.GroupingState; +import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.OrderedListState; import org.apache.beam.sdk.state.ReadableState; import org.apache.beam.sdk.state.ValueState; @@ -185,6 +194,391 @@ private WindmillStateReader.WeightedList weightedList(String... elems) { return result; } + private ByteString protoKeyFromUserKey(@Nullable K tag, Coder keyCoder) + throws IOException { + ByteString.Output keyStream = ByteString.newOutput(); + key(NAMESPACE, "map").writeTo(keyStream); + if (tag != null) { + keyCoder.encode(tag, keyStream, Context.OUTER); + } + return keyStream.toByteString(); + } + + private K userKeyFromProtoKey(ByteString tag, Coder keyCoder) throws IOException { + ByteString keyBytes = tag.substring(key(NAMESPACE, "map").size()); + return keyCoder.decode(keyBytes.newInput(), Context.OUTER); + } + + @Test + public void testMapAddBeforeGet() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag = "tag"; + SettableFuture future = SettableFuture.create(); + when(mockReader.valueFuture( + protoKeyFromUserKey(tag, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(future); + + ReadableState result = mapState.get("tag"); + result = result.readLater(); + waitAndSet(future, 1, 200); + assertEquals(1, (int) result.read()); + mapState.put("tag", 2); + assertEquals(2, (int) result.read()); + } + + @Test + public void testMapAddClearBeforeGet() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag = "tag"; + + SettableFuture>> prefixFuture = SettableFuture.create(); + when(mockReader.valuePrefixFuture( + protoKeyFromUserKey(null, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(prefixFuture); + + ReadableState result = mapState.get("tag"); + result = result.readLater(); + waitAndSet( + prefixFuture, + ImmutableList.of( + new AbstractMap.SimpleEntry<>(protoKeyFromUserKey(tag, StringUtf8Coder.of()), 1)), + 50); + assertFalse(mapState.isEmpty().read()); + mapState.clear(); + assertTrue(mapState.isEmpty().read()); + assertNull(mapState.get("tag").read()); + mapState.put("tag", 2); + assertFalse(mapState.isEmpty().read()); + assertEquals(2, (int) result.read()); + } + + @Test + public void testMapLocalAddOverridesStorage() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag = "tag"; + + SettableFuture future = SettableFuture.create(); + when(mockReader.valueFuture( + protoKeyFromUserKey(tag, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(future); + SettableFuture>> prefixFuture = SettableFuture.create(); + when(mockReader.valuePrefixFuture( + protoKeyFromUserKey(null, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(prefixFuture); + + waitAndSet(future, 1, 50); + waitAndSet( + prefixFuture, + ImmutableList.of( + new AbstractMap.SimpleEntry<>(protoKeyFromUserKey(tag, StringUtf8Coder.of()), 1)), + 50); + mapState.put(tag, 42); + assertEquals(42, (int) mapState.get(tag).read()); + assertThat( + mapState.entries().read(), + Matchers.containsInAnyOrder(new AbstractMap.SimpleEntry<>(tag, 42))); + } + + @Test + public void testMapLocalRemoveOverridesStorage() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag1 = "tag1"; + final String tag2 = "tag2"; + + SettableFuture future = SettableFuture.create(); + when(mockReader.valueFuture( + protoKeyFromUserKey(tag1, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(future); + SettableFuture>> prefixFuture = SettableFuture.create(); + when(mockReader.valuePrefixFuture( + protoKeyFromUserKey(null, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(prefixFuture); + + waitAndSet(future, 1, 50); + waitAndSet( + prefixFuture, + ImmutableList.of( + new AbstractMap.SimpleEntry<>(protoKeyFromUserKey(tag1, StringUtf8Coder.of()), 1), + new AbstractMap.SimpleEntry<>(protoKeyFromUserKey(tag2, StringUtf8Coder.of()), 2)), + 50); + mapState.remove(tag1); + assertNull(mapState.get(tag1).read()); + assertThat( + mapState.entries().read(), + Matchers.containsInAnyOrder(new AbstractMap.SimpleEntry<>(tag2, 2))); + + mapState.remove(tag2); + assertTrue(mapState.isEmpty().read()); + } + + @Test + public void testMapLocalClearOverridesStorage() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag1 = "tag1"; + final String tag2 = "tag2"; + + SettableFuture future1 = SettableFuture.create(); + SettableFuture future2 = SettableFuture.create(); + + when(mockReader.valueFuture( + protoKeyFromUserKey(tag1, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(future1); + when(mockReader.valueFuture( + protoKeyFromUserKey(tag2, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(future2); + SettableFuture>> prefixFuture = SettableFuture.create(); + when(mockReader.valuePrefixFuture( + protoKeyFromUserKey(null, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(prefixFuture); + + waitAndSet(future1, 1, 50); + waitAndSet(future2, 2, 50); + waitAndSet( + prefixFuture, + ImmutableList.of( + new AbstractMap.SimpleEntry<>(protoKeyFromUserKey(tag1, StringUtf8Coder.of()), 1), + new AbstractMap.SimpleEntry<>(protoKeyFromUserKey(tag2, StringUtf8Coder.of()), 2)), + 50); + mapState.clear(); + assertNull(mapState.get(tag1).read()); + assertNull(mapState.get(tag2).read()); + assertThat(mapState.entries().read(), Matchers.emptyIterable()); + assertTrue(mapState.isEmpty().read()); + } + + @Test + public void testMapAddBeforeRead() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag1 = "tag1"; + final String tag2 = "tag2"; + final String tag3 = "tag3"; + SettableFuture>> prefixFuture = SettableFuture.create(); + when(mockReader.valuePrefixFuture( + protoKeyFromUserKey(null, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(prefixFuture); + + ReadableState>> result = mapState.entries(); + result = result.readLater(); + + mapState.put(tag1, 1); + waitAndSet( + prefixFuture, + ImmutableList.of( + new AbstractMap.SimpleEntry<>(protoKeyFromUserKey(tag2, StringUtf8Coder.of()), 2)), + 200); + Iterable> readData = result.read(); + assertThat( + readData, + Matchers.containsInAnyOrder( + new AbstractMap.SimpleEntry<>(tag1, 1), new AbstractMap.SimpleEntry<>(tag2, 2))); + + mapState.put(tag3, 3); + assertThat( + result.read(), + Matchers.containsInAnyOrder( + new AbstractMap.SimpleEntry<>(tag1, 1), + new AbstractMap.SimpleEntry<>(tag2, 2), + new AbstractMap.SimpleEntry<>(tag3, 3))); + } + + @Test + public void testMapPutIfAbsentSucceeds() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag1 = "tag1"; + SettableFuture future = SettableFuture.create(); + when(mockReader.valueFuture( + protoKeyFromUserKey(tag1, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(future); + waitAndSet(future, null, 50); + + assertNull(mapState.putIfAbsent(tag1, 42).read()); + assertEquals(42, (int) mapState.get(tag1).read()); + } + + @Test + public void testMapPutIfAbsentFails() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag1 = "tag1"; + mapState.put(tag1, 1); + assertEquals(1, (int) mapState.putIfAbsent(tag1, 42).read()); + assertEquals(1, (int) mapState.get(tag1).read()); + + final String tag2 = "tag2"; + SettableFuture future = SettableFuture.create(); + when(mockReader.valueFuture( + protoKeyFromUserKey(tag2, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(future); + waitAndSet(future, 2, 50); + assertEquals(2, (int) mapState.putIfAbsent(tag2, 42).read()); + assertEquals(2, (int) mapState.get(tag2).read()); + } + + @Test + public void testMapNegativeCache() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag = "tag"; + SettableFuture future = SettableFuture.create(); + when(mockReader.valueFuture( + protoKeyFromUserKey(tag, StringUtf8Coder.of()), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(future); + waitAndSet(future, null, 200); + assertNull(mapState.get(tag).read()); + future.set(42); + assertNull(mapState.get(tag).read()); + } + + private Map.Entry fromTagValue( + TagValue tagValue, Coder keyCoder, Coder valueCoder) { + try { + V value = + !tagValue.getValue().getData().isEmpty() + ? valueCoder.decode(tagValue.getValue().getData().newInput()) + : null; + return new AbstractMap.SimpleEntry<>(userKeyFromProtoKey(tagValue.getTag(), keyCoder), value); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Test + public void testMapAddPersist() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag1 = "tag1"; + final String tag2 = "tag2"; + mapState.put(tag1, 1); + mapState.put(tag2, 2); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(2, commitBuilder.getValueUpdatesCount()); + assertThat( + commitBuilder.getValueUpdatesList().stream() + .map(tv -> fromTagValue(tv, StringUtf8Coder.of(), VarIntCoder.of())) + .collect(Collectors.toList()), + Matchers.containsInAnyOrder(new SimpleEntry<>(tag1, 1), new SimpleEntry<>(tag2, 2))); + } + + @Test + public void testMapRemovePersist() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag1 = "tag1"; + final String tag2 = "tag2"; + mapState.remove(tag1); + mapState.remove(tag2); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(2, commitBuilder.getValueUpdatesCount()); + assertThat( + commitBuilder.getValueUpdatesList().stream() + .map(tv -> fromTagValue(tv, StringUtf8Coder.of(), VarIntCoder.of())) + .collect(Collectors.toList()), + Matchers.containsInAnyOrder(new SimpleEntry<>(tag1, null), new SimpleEntry<>(tag2, null))); + } + + @Test + public void testMapClearPersist() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag1 = "tag1"; + final String tag2 = "tag2"; + mapState.put(tag1, 1); + mapState.put(tag2, 2); + mapState.clear(); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(0, commitBuilder.getValueUpdatesCount()); + assertEquals(1, commitBuilder.getTagValuePrefixDeletesCount()); + System.err.println(commitBuilder); + assertEquals(STATE_FAMILY, commitBuilder.getTagValuePrefixDeletes(0).getStateFamily()); + assertEquals( + protoKeyFromUserKey(null, StringUtf8Coder.of()), + commitBuilder.getTagValuePrefixDeletes(0).getTagPrefix()); + } + + @Test + public void testMapComplexPersist() throws Exception { + StateTag> addr = + StateTags.map("map", StringUtf8Coder.of(), VarIntCoder.of()); + MapState mapState = underTest.state(NAMESPACE, addr); + + final String tag1 = "tag1"; + final String tag2 = "tag2"; + final String tag3 = "tag3"; + final String tag4 = "tag4"; + + mapState.put(tag1, 1); + mapState.clear(); + mapState.put(tag2, 2); + mapState.put(tag3, 3); + mapState.remove(tag2); + mapState.remove(tag4); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + assertEquals(1, commitBuilder.getTagValuePrefixDeletesCount()); + assertEquals(STATE_FAMILY, commitBuilder.getTagValuePrefixDeletes(0).getStateFamily()); + assertEquals( + protoKeyFromUserKey(null, StringUtf8Coder.of()), + commitBuilder.getTagValuePrefixDeletes(0).getTagPrefix()); + assertThat( + commitBuilder.getValueUpdatesList().stream() + .map(tv -> fromTagValue(tv, StringUtf8Coder.of(), VarIntCoder.of())) + .collect(Collectors.toList()), + Matchers.containsInAnyOrder( + new SimpleEntry<>(tag3, 3), + new SimpleEntry<>(tag2, null), + new SimpleEntry<>(tag4, null))); + + // Once persist has been called, calling persist again should be a noop. + commitBuilder = Windmill.WorkItemCommitRequest.newBuilder(); + assertEquals(0, commitBuilder.getTagValuePrefixDeletesCount()); + assertEquals(0, commitBuilder.getValueUpdatesCount()); + } + public static final Range FULL_ORDERED_LIST_RANGE = Range.closedOpen(WindmillOrderedList.MIN_TS_MICROS, WindmillOrderedList.MAX_TS_MICROS); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java index 5ed5f3f2f3d9..bb1281cd566b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java @@ -23,12 +23,15 @@ import static org.junit.Assert.fail; import java.io.IOException; +import java.util.AbstractMap; +import java.util.Map; import java.util.concurrent.Future; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListEntry; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListRange; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.TimestampedValue; @@ -54,6 +57,7 @@ "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) }) public class WindmillStateReaderTest { + private static final StringUtf8Coder STRING_CODER = StringUtf8Coder.of(); private static final VarIntCoder INT_CODER = VarIntCoder.of(); private static final String COMPUTATION = "computation"; @@ -62,6 +66,7 @@ public class WindmillStateReaderTest { private static final long WORK_TOKEN = 5043L; private static final long CONT_POSITION = 1391631351L; + private static final ByteString STATE_KEY_PREFIX = ByteString.copyFromUtf8("key"); private static final ByteString STATE_KEY_1 = ByteString.copyFromUtf8("key1"); private static final ByteString STATE_KEY_2 = ByteString.copyFromUtf8("key2"); private static final String STATE_FAMILY = "family"; @@ -463,6 +468,139 @@ public void testReadSortedListWithContinuations() throws Exception { // NOTE: The future will still contain a reference to the underlying reader. } + @Test + public void testReadTagValuePrefix() throws Exception { + Future>> future = + underTest.valuePrefixFuture(STATE_KEY_PREFIX, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addTagValuePrefixesToFetch( + Windmill.TagValuePrefixRequest.newBuilder() + .setTagPrefix(STATE_KEY_PREFIX) + .setStateFamily(STATE_FAMILY) + .setFetchMaxBytes(WindmillStateReader.MAX_TAG_VALUE_PREFIX_BYTES)); + + Windmill.KeyedGetDataResponse.Builder response = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagValuePrefixes( + Windmill.TagValuePrefixResponse.newBuilder() + .setTagPrefix(STATE_KEY_PREFIX) + .setStateFamily(STATE_FAMILY) + .addTagValues( + Windmill.TagValue.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setValue(intValue(8))) + .addTagValues( + Windmill.TagValue.newBuilder() + .setTag(STATE_KEY_2) + .setStateFamily(STATE_FAMILY) + .setValue(intValue(9)))); + + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build())) + .thenReturn(response.build()); + + Iterable> result = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat( + result, + Matchers.containsInAnyOrder( + new AbstractMap.SimpleEntry<>(STATE_KEY_1, 8), + new AbstractMap.SimpleEntry<>(STATE_KEY_2, 9))); + + assertNoReader(future); + } + + @Test + public void testReadTagValuePrefixWithContinuations() throws Exception { + Future>> future = + underTest.valuePrefixFuture(STATE_KEY_PREFIX, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addTagValuePrefixesToFetch( + Windmill.TagValuePrefixRequest.newBuilder() + .setTagPrefix(STATE_KEY_PREFIX) + .setStateFamily(STATE_FAMILY) + .setFetchMaxBytes(WindmillStateReader.MAX_TAG_VALUE_PREFIX_BYTES)); + + final ByteString CONT = ByteString.copyFrom("CONTINUATION", Charsets.UTF_8); + Windmill.KeyedGetDataResponse.Builder response1 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagValuePrefixes( + Windmill.TagValuePrefixResponse.newBuilder() + .setTagPrefix(STATE_KEY_PREFIX) + .setStateFamily(STATE_FAMILY) + .setContinuationPosition(CONT) + .addTagValues( + Windmill.TagValue.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setValue(intValue(8)))); + + Windmill.KeyedGetDataRequest.Builder expectedRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addTagValuePrefixesToFetch( + Windmill.TagValuePrefixRequest.newBuilder() + .setTagPrefix(STATE_KEY_PREFIX) + .setStateFamily(STATE_FAMILY) + .setRequestPosition(CONT) + .setFetchMaxBytes(WindmillStateReader.MAX_TAG_VALUE_PREFIX_BYTES)); + + Windmill.KeyedGetDataResponse.Builder response2 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagValuePrefixes( + Windmill.TagValuePrefixResponse.newBuilder() + .setTagPrefix(STATE_KEY_PREFIX) + .setStateFamily(STATE_FAMILY) + .setRequestPosition(CONT) + .addTagValues( + Windmill.TagValue.newBuilder() + .setTag(STATE_KEY_2) + .setStateFamily(STATE_FAMILY) + .setValue(intValue(9)))); + + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build())) + .thenReturn(response1.build()); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build())) + .thenReturn(response2.build()); + + Iterable> results = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build()); + for (Map.Entry unused : results) { + // Iterate over the results to force loading all the pages. + } + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build()); + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat( + results, + Matchers.containsInAnyOrder( + new AbstractMap.SimpleEntry<>(STATE_KEY_1, 8), + new AbstractMap.SimpleEntry<>(STATE_KEY_2, 9))); + // NOTE: The future will still contain a reference to the underlying reader. + } + @Test public void testReadValue() throws Exception { Future future = underTest.valueFuture(STATE_KEY_1, STATE_FAMILY, INT_CODER); diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index b0e8bda375ff..fdb37ba06971 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -87,6 +87,29 @@ message TagValue { optional string state_family = 3; } +message TagValuePrefix { + optional bytes tag_prefix = 1; + optional string state_family = 2; +} + +message TagValuePrefixRequest { + optional bytes tag_prefix = 1; + optional string state_family = 2; + // In request: A previously returned continuation_token from an earlier + // request. Indicates we wish to fetch the next page of values. + // In response: Copied from request. + optional bytes request_position = 3; + optional int64 fetch_max_bytes = 4 [default = 0x6400000]; +} + +message TagValuePrefixResponse { + optional bytes tag_prefix = 1; + optional string state_family = 2; + repeated TagValue tag_values = 3; + optional bytes continuation_position = 4; + optional bytes request_position = 5; +} + message TagBag { optional bytes tag = 1; // In request: All existing items in the list will be deleted. If new values @@ -256,6 +279,7 @@ message KeyedGetDataRequest { required fixed64 work_token = 2; optional fixed64 sharding_key = 6; repeated TagValue values_to_fetch = 3; + repeated TagValuePrefixRequest tag_value_prefixes_to_fetch = 10; repeated TagBag bags_to_fetch = 8; // Must be at most one sorted_list_to_fetch for a given state family and tag. repeated TagSortedListFetchRequest sorted_lists_to_fetch = 9; @@ -286,6 +310,7 @@ message KeyedGetDataResponse { // The response for this key is not populated due to the fetch failing. optional bool failed = 2; repeated TagValue values = 3; + repeated TagValuePrefixResponse tag_value_prefixes = 9; repeated TagBag bags = 6; // There is one TagSortedListFetchResponse per state-family, tag pair. repeated TagSortedListFetchResponse tag_sorted_lists = 8; @@ -351,6 +376,7 @@ message WorkItemCommitRequest { repeated PubSubMessageBundle pubsub_messages = 7; repeated Timer output_timers = 4; repeated TagValue value_updates = 5; + repeated TagValuePrefix tag_value_prefix_deletes = 25; repeated TagBag bag_updates = 18; repeated TagSortedListUpdateRequest sorted_list_updates = 24; repeated Counter counter_updates = 8; diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternals.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternals.java index cd30d2286dd4..f908468db7a5 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternals.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternals.java @@ -32,6 +32,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import javax.annotation.Nonnull; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateInternalsFactory; @@ -73,7 +74,10 @@ import org.apache.samza.storage.kv.Entry; import org.apache.samza.storage.kv.KeyValueIterator; import org.apache.samza.storage.kv.KeyValueStore; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.joda.time.Instant; /** {@link StateInternals} that uses Samza local {@link KeyValueStore} to manage state. */ @@ -620,11 +624,12 @@ public void put(KeyT key, ValueT value) { } @Override - public @Nullable ReadableState putIfAbsent(KeyT key, ValueT value) { + public @Nullable ReadableState computeIfAbsent( + KeyT key, Function mappingFunction) { final ByteArray encodedKey = encodeKey(key); final ValueT current = decodeValue(store.get(encodedKey)); if (current == null) { - put(key, value); + put(key, mappingFunction.apply(key)); } return current == null ? null : ReadableStates.immediate(current); @@ -637,8 +642,24 @@ public void remove(KeyT key) { @Override public ReadableState get(KeyT key) { - ValueT value = decodeValue(store.get(encodeKey(key))); - return ReadableStates.immediate(value); + return getOrDefault(key, null); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState getOrDefault( + KeyT key, @Nullable ValueT defaultValue) { + return new ReadableState() { + @Override + public @Nullable ValueT read() { + ValueT value = decodeValue(store.get(encodeKey(key))); + return value != null ? value : defaultValue; + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + return this; + } + }; } @Override @@ -689,6 +710,25 @@ public ReadableState>> readLater() { }; } + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState< + @UnknownKeyFor @NonNull @Initialized Boolean> + isEmpty() { + ReadableState> keys = this.keys(); + return new ReadableState() { + @Override + public @Nullable Boolean read() { + return Iterables.isEmpty(keys.read()); + } + + @Override + public @UnknownKeyFor @NonNull @Initialized ReadableState readLater() { + keys.readLater(); + return this; + } + }; + } + @Override public ReadableState>> readIterator() { final ByteArray maxKey = createMaxKey(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java index f5f5dd2b4744..6c05ba8940a7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java @@ -18,6 +18,8 @@ package org.apache.beam.sdk.state; import java.util.Map; +import java.util.function.Function; +import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; @@ -54,8 +56,37 @@ public interface MapState extends State { *

Changes will not be reflected in the results returned by previous calls to {@link * ReadableState#read} on the results any of the reading methods ({@link #get}, {@link #keys}, * {@link #values}, and {@link #entries}). + * + *

Since the condition is not evaluated until {@link ReadableState#read} is called, a call to + * {@link #putIfAbsent} followed by a call to {@link #remove} followed by a read on the + * putIfAbsent return will result in the item being written to the map. Similarly, if there are + * multiple calls to {@link #putIfAbsent} for the same key, precedence will be given to the first + * one on which read is called. + */ + default ReadableState putIfAbsent(K key, V value) { + return computeIfAbsent(key, k -> value); + } + + /** + * A deferred read-followed-by-write. + * + *

When {@code read()} is called on the result or state is committed, it forces a read of the + * map and reconciliation with any pending modifications. + * + *

If the specified key is not already associated with a value (or is mapped to {@code null}) + * associates it with the computed and returns {@code null}, else returns the current value. + * + *

Changes will not be reflected in the results returned by previous calls to {@link + * ReadableState#read} on the results any of the reading methods ({@link #get}, {@link #keys}, + * {@link #values}, and {@link #entries}). + * + *

Since the condition is not evaluated until {@link ReadableState#read} is called, a call to + * {@link #putIfAbsent} followed by a call to {@link #remove} followed by a read on the + * putIfAbsent return will result in the item being written to the map. Similarly, if there are + * multiple calls to {@link #putIfAbsent} for the same key, precedence will be given to the first + * one on which read is called. */ - ReadableState putIfAbsent(K key, V value); + ReadableState computeIfAbsent(K key, Function mappingFunction); /** * Remove the mapping for a key from this map if it is present. @@ -67,7 +98,7 @@ public interface MapState extends State { void remove(K key); /** - * A deferred lookup. + * A deferred lookup, using null values if the item is not found. * *

A user is encouraged to call {@code get} for all relevant keys and call {@code readLater()} * on the results. @@ -77,6 +108,17 @@ public interface MapState extends State { */ ReadableState get(K key); + /** + * A deferred lookup. + * + *

A user is encouraged to call {@code get} for all relevant keys and call {@code readLater()} + * on the results. + * + *

When {@code read()} is called, a particular state implementation is encouraged to perform + * all pending reads in a single batch. + */ + ReadableState getOrDefault(K key, @Nullable V defaultValue); + /** Returns an {@link Iterable} over the keys contained in this map. */ ReadableState> keys(); @@ -85,4 +127,10 @@ public interface MapState extends State { /** Returns an {@link Iterable} over the key-value pairs contained in this map. */ ReadableState>> entries(); + + /** + * Returns a {@link ReadableState} whose {@link ReadableState#read} method will return true if + * this state is empty at the point when that {@link ReadableState#read} call returns. + */ + ReadableState isEmpty(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java index 000ab549fe11..b812080f5acf 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java @@ -23,6 +23,7 @@ import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.BooleanCoder; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; @@ -320,6 +321,25 @@ public static StateSpec> convertToBag } } + /** + * For internal use only; no backwards-compatibility guarantees. + * + *

Convert a set state spec to a map-state spec. + */ + @Internal + public static StateSpec> convertToMapSpecInternal( + StateSpec> setStateSpec) { + if (setStateSpec instanceof SetStateSpec) { + // Checked above; conversion to a map spec depends on the provided spec being one of those + // created via the factory methods in this class. + @SuppressWarnings("unchecked") + SetStateSpec typedSpec = (SetStateSpec) setStateSpec; + return typedSpec.asMapSpec(); + } else { + throw new IllegalArgumentException("Unexpected StateSpec " + setStateSpec); + } + } + /** * A specification for a state cell holding a settable value of type {@code T}. * @@ -773,6 +793,10 @@ public boolean equals(@Nullable Object obj) { public int hashCode() { return Objects.hash(getClass(), elemCoder); } + + private StateSpec> asMapSpec() { + return new MapStateSpec<>(this.elemCoder, BooleanCoder.of()); + } } /**