Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.examples.cookbook;

import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.joda.time.Duration;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class MapClassIntegrationIT {

static class MapDoFn extends DoFn<KV<String, Long>, Void> {
@StateId("mapState")
private final StateSpec<MapState<String, Long>> mapStateSpec = StateSpecs.map();

@ProcessElement
public void processElement(
@Element KV<String, Long> element, @StateId("mapState") MapState<String, Long> mapState) {
mapState.put(Long.toString(element.getValue() % 100), element.getValue());
if (element.getValue() % 1000 == 0) {
Iterable<Map.Entry<String, Long>> entries = mapState.entries().read();
if (entries != null) {
System.err.println("ENTRIES " + Iterables.toString(entries));
} else {
System.err.println("ENTRIES IS NULL");
}
}
}
}

@Test
public void testDataflowMapState() {
PipelineOptions options = TestPipeline.testingPipelineOptions();
Pipeline p = Pipeline.create(options);
p.apply(
"GenerateSequence",
GenerateSequence.from(0).withRate(1000, Duration.standardSeconds(1)))
.apply("WithKeys", WithKeys.of("key"))
.apply("MapState", ParDo.of(new MapDoFn()));
p.run();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.NavigableMap;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.runners.core.StateTag.StateBinder;
Expand Down Expand Up @@ -55,6 +56,9 @@
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
import org.joda.time.Instant;

/**
Expand Down Expand Up @@ -641,7 +645,23 @@ public void clear() {

@Override
public ReadableState<V> get(K key) {
return ReadableStates.immediate(contents.get(key));
return getOrDefault(key, null);
}

@Override
public @UnknownKeyFor @NonNull @Initialized ReadableState<V> getOrDefault(
K key, @Nullable V defaultValue) {
return new ReadableState<V>() {
@Override
public @org.checkerframework.checker.nullness.qual.Nullable V read() {
return contents.getOrDefault(key, defaultValue);
}

@Override
public @UnknownKeyFor @NonNull @Initialized ReadableState<V> readLater() {
return this;
}
};
}

@Override
Expand All @@ -650,10 +670,11 @@ public void put(K key, V value) {
}

@Override
public ReadableState<V> putIfAbsent(K key, V value) {
public ReadableState<V> computeIfAbsent(
K key, Function<? super K, ? extends V> mappingFunction) {
V v = contents.get(key);
if (v == null) {
v = contents.put(key, value);
v = contents.put(key, mappingFunction.apply(key));
}

return ReadableStates.immediate(v);
Expand Down Expand Up @@ -701,6 +722,23 @@ public ReadableState<Iterable<Map.Entry<K, V>>> entries() {
return CollectionViewState.of(contents.entrySet());
}

@Override
public @UnknownKeyFor @NonNull @Initialized ReadableState<
@UnknownKeyFor @NonNull @Initialized Boolean>
isEmpty() {
return new ReadableState<Boolean>() {
@Override
public @org.checkerframework.checker.nullness.qual.Nullable Boolean read() {
return contents.isEmpty();
}

@Override
public @UnknownKeyFor @NonNull @Initialized ReadableState<Boolean> readLater() {
return this;
}
};
}

@Override
public boolean isCleared() {
return contents.isEmpty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ public static <InputT, AccumT, OutputT> StateTag<BagState<AccumT>> convertToBagT
StateSpecs.convertToBagSpecInternal(combiningTag.getSpec()));
}

public static <KeyT> StateTag<MapState<KeyT, Boolean>> convertToMapTagInternal(
StateTag<SetState<KeyT>> setTag) {
return new SimpleStateTag<>(
new StructuredId(setTag.getId()), StateSpecs.convertToMapSpecInternal(setTag.getSpec()));
}

private static class StructuredId implements Serializable {
private final StateKind kind;
private final String rawId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ public void testMapReadable() throws Exception {
// test get
ReadableState<Integer> get = value.get("B");
value.put("B", 2);
assertNull(get.read());
assertThat(get.read(), equalTo(2));

// test addIfAbsent
value.putIfAbsent("C", 3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import org.apache.beam.runners.core.StateInternals;
Expand Down Expand Up @@ -72,7 +73,10 @@
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
import org.joda.time.Instant;

/**
Expand Down Expand Up @@ -1033,15 +1037,32 @@ private static class FlinkMapState<KeyT, ValueT> implements MapState<KeyT, Value

@Override
public ReadableState<ValueT> get(final KeyT input) {
try {
return ReadableStates.immediate(
flinkStateBackend
.getPartitionedState(
namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
.get(input));
} catch (Exception e) {
throw new RuntimeException("Error get from state.", e);
}
return getOrDefault(input, null);
}

@Override
public @UnknownKeyFor @NonNull @Initialized ReadableState<ValueT> getOrDefault(
KeyT key, @Nullable ValueT defaultValue) {
return new ReadableState<ValueT>() {
@Override
public @Nullable ValueT read() {
try {
ValueT value =
flinkStateBackend
.getPartitionedState(
namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
.get(key);
return (value != null) ? value : defaultValue;
} catch (Exception e) {
throw new RuntimeException("Error get from state.", e);
}
}

@Override
public @UnknownKeyFor @NonNull @Initialized ReadableState<ValueT> readLater() {
return this;
}
};
}

@Override
Expand All @@ -1057,7 +1078,8 @@ public void put(KeyT key, ValueT value) {
}

@Override
public ReadableState<ValueT> putIfAbsent(final KeyT key, final ValueT value) {
public ReadableState<ValueT> computeIfAbsent(
final KeyT key, Function<? super KeyT, ? extends ValueT> mappingFunction) {
try {
ValueT current =
flinkStateBackend
Expand All @@ -1069,7 +1091,7 @@ public ReadableState<ValueT> putIfAbsent(final KeyT key, final ValueT value) {
flinkStateBackend
.getPartitionedState(
namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
.put(key, value);
.put(key, mappingFunction.apply(key));
}
return ReadableStates.immediate(current);
} catch (Exception e) {
Expand Down Expand Up @@ -1161,6 +1183,25 @@ public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> readLater() {
};
}

@Override
public @UnknownKeyFor @NonNull @Initialized ReadableState<
@UnknownKeyFor @NonNull @Initialized Boolean>
isEmpty() {
ReadableState<Iterable<KeyT>> keys = this.keys();
return new ReadableState<Boolean>() {
@Override
public @Nullable Boolean read() {
return Iterables.isEmpty(keys.read());
}

@Override
public @UnknownKeyFor @NonNull @Initialized ReadableState<Boolean> readLater() {
keys.readLater();
return this;
}
};
}

@Override
public void clear() {
try {
Expand Down
2 changes: 0 additions & 2 deletions runners/google-cloud-dataflow-java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ def commonLegacyExcludeCategories = [
'org.apache.beam.sdk.testing.UsesCrossLanguageTransforms',
'org.apache.beam.sdk.testing.UsesDistributionMetrics',
'org.apache.beam.sdk.testing.UsesGaugeMetrics',
'org.apache.beam.sdk.testing.UsesSetState',
'org.apache.beam.sdk.testing.UsesMapState',
'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs',
'org.apache.beam.sdk.testing.UsesTestStream',
'org.apache.beam.sdk.testing.UsesParDoLifecycle',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,13 @@ ParDo.SingleOutput<KV<K, InputT>, OutputT> getOriginalParDo() {
public PCollection<OutputT> expand(PCollection<KV<K, InputT>> input) {
DoFn<KV<K, InputT>, OutputT> fn = originalParDo.getFn();
verifyFnIsStateful(fn);
DataflowRunner.verifyDoFnSupportedBatch(fn);
DataflowPipelineOptions options =
input.getPipeline().getOptions().as(DataflowPipelineOptions.class);
DataflowRunner.verifyDoFnSupported(
fn,
false,
DataflowRunner.useUnifiedWorker(options),
DataflowRunner.useStreamingEngine(options));
DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy());

if (isFnApi) {
Expand Down Expand Up @@ -209,7 +215,13 @@ static class StatefulMultiOutputParDo<K, InputT, OutputT>
public PCollectionTuple expand(PCollection<KV<K, InputT>> input) {
DoFn<KV<K, InputT>, OutputT> fn = originalParDo.getFn();
verifyFnIsStateful(fn);
DataflowRunner.verifyDoFnSupportedBatch(fn);
DataflowPipelineOptions options =
input.getPipeline().getOptions().as(DataflowPipelineOptions.class);
DataflowRunner.verifyDoFnSupported(
fn,
false,
DataflowRunner.useUnifiedWorker(options),
DataflowRunner.useStreamingEngine(options));
DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy());

if (isFnApi) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1252,7 +1252,12 @@ private static void translateFn(

boolean isStateful = DoFnSignatures.isStateful(fn);
if (isStateful) {
DataflowRunner.verifyDoFnSupported(fn, context.getPipelineOptions().isStreaming());
DataflowPipelineOptions options = context.getPipelineOptions();
DataflowRunner.verifyDoFnSupported(
fn,
options.isStreaming(),
DataflowRunner.useUnifiedWorker(options),
DataflowRunner.useStreamingEngine(options));
DataflowRunner.verifyStateSupportForWindowingStrategy(windowingStrategy);
}

Expand Down
Loading