diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index 901856b93dae..01af3099a78d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -28,6 +28,8 @@ import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; @@ -51,10 +53,10 @@ import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress; -import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Progress; import org.apache.beam.sdk.transforms.splittabledofn.SplitResult; import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.MemoizingPerInstantiationSerializableSupplier; import org.apache.beam.sdk.util.NameUtils; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.PBegin; @@ -65,10 +67,11 @@ import org.apache.beam.sdk.values.ValueWithRecordId.StripIdsDoFn; import org.apache.beam.sdk.values.ValueWithRecordId.ValueWithRecordIdCoder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalCause; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListener; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.checkerframework.checker.nullness.qual.EnsuresNonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.common.value.qual.ArrayLen; @@ -481,12 +484,37 @@ static class UnboundedSourceAsSDFWrapperFn checkpointCoder; - private @Nullable Cache> cachedReaders; + private final MemoizingPerInstantiationSerializableSupplier< + Cache>> + readerCacheSupplier; + private static final Executor closeExecutor = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat("UnboundedReaderCloses-%d").build()); private @Nullable Coder> restrictionCoder; @VisibleForTesting UnboundedSourceAsSDFWrapperFn(Coder checkpointCoder) { this.checkpointCoder = checkpointCoder; + readerCacheSupplier = + new MemoizingPerInstantiationSerializableSupplier<>( + () -> + CacheBuilder.newBuilder() + .expireAfterWrite(1, TimeUnit.MINUTES) + .removalListener( + (RemovalListener>) + removalNotification -> { + if (removalNotification.getCause() != RemovalCause.EXPLICIT) { + closeExecutor.execute( + () -> { + try { + checkStateNotNull(removalNotification.getValue()).close(); + } catch (IOException e) { + LOG.warn("Failed to close UnboundedReader.", e); + } + }); + } + }) + .build()); } @GetInitialRestriction @@ -498,22 +526,6 @@ public UnboundedSourceRestriction initialRestriction( @Setup public void setUp() throws Exception { restrictionCoder = restrictionCoder(); - cachedReaders = - CacheBuilder.newBuilder() - .expireAfterWrite(1, TimeUnit.MINUTES) - .maximumSize(100) - .removalListener( - (RemovalListener>) - removalNotification -> { - if (removalNotification.wasEvicted()) { - try { - Preconditions.checkNotNull(removalNotification.getValue()).close(); - } catch (IOException e) { - LOG.warn("Failed to close UnboundedReader.", e); - } - } - }) - .build(); } 
@SplitRestriction @@ -556,7 +568,8 @@ public void splitRestriction( PipelineOptions pipelineOptions) { Coder> restrictionCoder = checkStateNotNull(this.restrictionCoder); - Cache> cachedReaders = checkStateNotNull(this.cachedReaders); + Cache> cachedReaders = + checkStateNotNull(this.readerCacheSupplier.get()); return new UnboundedSourceAsSDFRestrictionTracker<>( restriction, pipelineOptions, cachedReaders, restrictionCoder); } @@ -840,10 +853,11 @@ private static class UnboundedSourceAsSDFRestrictionTracker< implements HasProgress { private final UnboundedSourceRestriction initialRestriction; private final PipelineOptions pipelineOptions; + private final Cache> cachedReaders; + private final Coder> restrictionCoder; + private UnboundedSource.@Nullable UnboundedReader currentReader; private boolean readerHasBeenStarted; - private Cache> cachedReaders; - private Coder> restrictionCoder; UnboundedSourceAsSDFRestrictionTracker( UnboundedSourceRestriction initialRestriction, @@ -870,7 +884,8 @@ private void initializeCurrentReader() throws IOException { checkState(currentReader == null); Object cacheKey = createCacheKey(initialRestriction.getSource(), initialRestriction.getCheckpoint()); - UnboundedReader cachedReader = cachedReaders.getIfPresent(cacheKey); + // We remove the reader if cached so that it is not possibly claimed by multiple DoFns. + UnboundedReader cachedReader = cachedReaders.asMap().remove(cacheKey); if (cachedReader == null) { this.currentReader = @@ -879,9 +894,7 @@ private void initializeCurrentReader() throws IOException { .createReader(pipelineOptions, initialRestriction.getCheckpoint()); } else { // If the reader is from cache, then we know that the reader has been started. - // We also remove this cache entry to avoid eviction. 
readerHasBeenStarted = true; - cachedReaders.invalidate(cacheKey); this.currentReader = cachedReader; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java new file mode 100644 index 000000000000..b7b62cd24274 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** + * A supplier that memoizes within an instantiation across serialization/deserialization. + * + *

<p>Specifically, the wrapped supplier will be called once and the result memoized per group + * consisting of an instance and all instances deserialized from its serialized state. + * + *

A particular use for this is within a DoFn class to maintain shared state across all instances + * of the DoFn that correspond to same step in the graph but separate from other steps in the graph + * using the same DoFn. This differs from a static variable which would be shared across all + * instances of the DoFn and a non-static variable which is per instance. + */ +public class MemoizingPerInstantiationSerializableSupplier implements SerializableSupplier { + private static final AtomicInteger idGenerator = new AtomicInteger(); + private final int id; + + private static final ConcurrentHashMap staticCache = new ConcurrentHashMap<>(); + private final SerializableSupplier<@NonNull T> supplier; + private transient volatile @MonotonicNonNull T value; + + public MemoizingPerInstantiationSerializableSupplier(SerializableSupplier<@NonNull T> supplier) { + id = idGenerator.incrementAndGet(); + this.supplier = supplier; + } + + @Override + @SuppressWarnings("unchecked") + public T get() { + @Nullable T result = value; + if (result != null) { + return result; + } + @Nullable T mapValue = (T) staticCache.computeIfAbsent(id, ignored -> supplier.get()); + return value = Preconditions.checkStateNotNull(mapValue); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java index aa528c4f08f4..c6e570cdfcfb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java @@ -179,12 +179,7 @@ public void testUnboundedSdfWrapperCacheStartedReaders() { // read is default. ExperimentalOptions.addExperiment( pipeline.getOptions().as(ExperimentalOptions.class), "use_sdf_read"); - // Force the pipeline to run with one thread to ensure the reader will be reused on one DoFn - // instance. - // We are not able to use DirectOptions because of circular dependency. 
- pipeline - .runWithAdditionalOptionArgs(ImmutableList.of("--targetParallelism=1")) - .waitUntilFinish(); + pipeline.run().waitUntilFinish(); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplierTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplierTest.java new file mode 100644 index 000000000000..216682276ecf --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplierTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MemoizingPerInstantiationSerializableSupplierTest { + + @SuppressWarnings("unchecked") + @Test + public void testSharedAcrossDeserialize() throws Exception { + MemoizingPerInstantiationSerializableSupplier instance = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance); + + AtomicInteger i = instance.get(); + i.set(10); + assertSame(i, instance.get()); + + byte[] serialized = SerializableUtils.serializeToByteArray(instance); + MemoizingPerInstantiationSerializableSupplier deserialized1 = + (MemoizingPerInstantiationSerializableSupplier) + SerializableUtils.deserializeFromByteArray(serialized, "instance"); + assertSame(i, deserialized1.get()); + + MemoizingPerInstantiationSerializableSupplier deserialized2 = + (MemoizingPerInstantiationSerializableSupplier) + SerializableUtils.deserializeFromByteArray(serialized, "instance"); + assertSame(i, deserialized2.get()); + assertEquals(10, i.get()); + } + + @Test + public void testDifferentInstancesSeparate() throws Exception { + MemoizingPerInstantiationSerializableSupplier instance = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance); + AtomicInteger i = instance.get(); + i.set(10); + assertSame(i, instance.get()); + + MemoizingPerInstantiationSerializableSupplier instance2 = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance2); + AtomicInteger j = instance2.get(); + j.set(20); + assertSame(j, 
instance2.get()); + assertNotSame(j, i); + + MemoizingPerInstantiationSerializableSupplier instance1clone = + SerializableUtils.clone(instance); + assertSame(instance1clone.get(), i); + MemoizingPerInstantiationSerializableSupplier instance2clone = + SerializableUtils.clone(instance2); + assertSame(instance2clone.get(), j); + } + + @SuppressWarnings("unchecked") + @Test + public void testDifferentInstancesSeparateNoGetBeforeSerialization() throws Exception { + MemoizingPerInstantiationSerializableSupplier instance = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance); + + MemoizingPerInstantiationSerializableSupplier instance2 = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance2); + + byte[] serialized = SerializableUtils.serializeToByteArray(instance); + MemoizingPerInstantiationSerializableSupplier deserialized1 = + (MemoizingPerInstantiationSerializableSupplier) + SerializableUtils.deserializeFromByteArray(serialized, "instance"); + MemoizingPerInstantiationSerializableSupplier deserialized2 = + (MemoizingPerInstantiationSerializableSupplier) + SerializableUtils.deserializeFromByteArray(serialized, "instance"); + assertSame(deserialized1.get(), deserialized2.get()); + + MemoizingPerInstantiationSerializableSupplier instance2clone = + SerializableUtils.clone(instance2); + assertNotSame(instance2clone.get(), deserialized1.get()); + } + + @Test + public void testDifferentTypes() throws Exception { + MemoizingPerInstantiationSerializableSupplier instance = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance); + AtomicInteger i = instance.get(); + i.set(10); + assertSame(i, instance.get()); + + MemoizingPerInstantiationSerializableSupplier> instance2 = + new MemoizingPerInstantiationSerializableSupplier<>(ConcurrentHashMap::new); + 
SerializableUtils.ensureSerializable(instance2); + ConcurrentHashMap j = instance2.get(); + j.put(1, 100); + assertSame(j, instance2.get()); + + MemoizingPerInstantiationSerializableSupplier instance1clone = + SerializableUtils.clone(instance); + assertSame(instance1clone.get(), i); + MemoizingPerInstantiationSerializableSupplier> + instance2clone = SerializableUtils.clone(instance2); + assertSame(instance2clone.get(), j); + } +}