From a420191e04e9d730b6fc504310d50e7f73e1d21b Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 6 Feb 2025 15:46:48 +0100 Subject: [PATCH 1/6] Add a utility class to enable sharing across all deserialized instances of a DoFn and use it in UnboundedSourceAsSdfWrapperFn to cache Readers across dofn instances --- .../java/org/apache/beam/sdk/io/Read.java | 41 +++--- .../beam/sdk/util/PerSerializationStatic.java | 58 ++++++++ .../java/org/apache/beam/sdk/io/ReadTest.java | 7 +- .../sdk/util/PerSerializationStaticTest.java | 126 ++++++++++++++++++ 4 files changed, 208 insertions(+), 24 deletions(-) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/util/PerSerializationStatic.java create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/util/PerSerializationStaticTest.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index 901856b93dae..c5a087366e53 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -56,6 +56,7 @@ import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.NameUtils; +import org.apache.beam.sdk.util.PerSerializationStatic; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -481,12 +482,31 @@ static class UnboundedSourceAsSDFWrapperFn checkpointCoder; - private @Nullable Cache> cachedReaders; + private final PerSerializationStatic>> cachedReaders; private @Nullable Coder> restrictionCoder; @VisibleForTesting UnboundedSourceAsSDFWrapperFn(Coder checkpointCoder) { this.checkpointCoder = checkpointCoder; + cachedReaders = + new PerSerializationStatic<>( + () -> + CacheBuilder.newBuilder() + .expireAfterWrite(1, TimeUnit.MINUTES) + .maximumSize(100) + .removalListener( + (RemovalListener>) + removalNotification -> { + if (removalNotification.wasEvicted()) { + try { + Preconditions.checkNotNull(removalNotification.getValue()) + .close(); + } catch (IOException e) { + LOG.warn("Failed to close UnboundedReader.", e); + } + } + }) + .build()); } @GetInitialRestriction @@ -498,22 +518,6 @@ public UnboundedSourceRestriction initialRestriction( @Setup public void setUp() throws Exception { restrictionCoder = restrictionCoder(); - cachedReaders = - CacheBuilder.newBuilder() - .expireAfterWrite(1, TimeUnit.MINUTES) - .maximumSize(100) - .removalListener( - (RemovalListener>) - removalNotification -> { - if (removalNotification.wasEvicted()) { - try { - Preconditions.checkNotNull(removalNotification.getValue()).close(); - } catch (IOException e) { - LOG.warn("Failed to close UnboundedReader.", e); - } - } - }) - .build(); } @SplitRestriction @@ -556,7 +560,8 @@ public void splitRestriction( PipelineOptions pipelineOptions) { Coder> restrictionCoder = checkStateNotNull(this.restrictionCoder); - Cache> cachedReaders = checkStateNotNull(this.cachedReaders); + Cache> cachedReaders = + checkStateNotNull(this.cachedReaders.get()); return new UnboundedSourceAsSDFRestrictionTracker<>( restriction, pipelineOptions, cachedReaders, restrictionCoder); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PerSerializationStatic.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PerSerializationStatic.java new file mode 100644 index 000000000000..e53648047f1b --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PerSerializationStatic.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util; + +import java.io.Serializable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** + * An object that simplifies having a variable that behaves like a static object but which is scoped + * to deserialized instances. + * + *

In particular this can be useful for use within a DoFn class to maintain shared state across + * all instances of the DoFn that are the same step in the graph. This differs from a static + * variable which would be shared across all instances of the DoFn and a non-static variable which + * is per instance. + */ +public class PerSerializationStatic implements Serializable { + private static final AtomicInteger idGenerator = new AtomicInteger(); + private final int id; + + private static final ConcurrentHashMap staticCache = new ConcurrentHashMap<>(); + private final SerializableSupplier<@NonNull T> supplier; + private transient volatile @MonotonicNonNull T value; + + public PerSerializationStatic(SerializableSupplier<@NonNull T> supplier) { + id = idGenerator.incrementAndGet(); + this.supplier = supplier; + } + + @SuppressWarnings("unchecked") + public T get() { + @Nullable T result = value; + if (result != null) { + return result; + } + @Nullable T mapValue = (T) staticCache.computeIfAbsent(id, ignored -> supplier.get()); + return value = Preconditions.checkStateNotNull(mapValue); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java index aa528c4f08f4..c6e570cdfcfb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java @@ -179,12 +179,7 @@ public void testUnboundedSdfWrapperCacheStartedReaders() { // read is default. ExperimentalOptions.addExperiment( pipeline.getOptions().as(ExperimentalOptions.class), "use_sdf_read"); - // Force the pipeline to run with one thread to ensure the reader will be reused on one DoFn - // instance. - // We are not able to use DirectOptions because of circular dependency. - pipeline - .runWithAdditionalOptionArgs(ImmutableList.of("--targetParallelism=1")) - .waitUntilFinish(); + pipeline.run().waitUntilFinish(); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PerSerializationStaticTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PerSerializationStaticTest.java new file mode 100644 index 000000000000..c95dbf436a7d --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PerSerializationStaticTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class PerSerializationStaticTest { + + @SuppressWarnings("unchecked") + @Test + public void testSharedAcrossDeserialize() throws Exception { + PerSerializationStatic instance = + new PerSerializationStatic<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance); + + AtomicInteger i = instance.get(); + i.set(10); + assertSame(i, instance.get()); + + byte[] serialized = SerializableUtils.serializeToByteArray(instance); + PerSerializationStatic deserialized1 = + (PerSerializationStatic) + SerializableUtils.deserializeFromByteArray(serialized, "instance"); + assertSame(i, deserialized1.get()); + + PerSerializationStatic deserialized2 = + (PerSerializationStatic) + SerializableUtils.deserializeFromByteArray(serialized, "instance"); + assertSame(i, deserialized2.get()); + assertEquals(10, i.get()); + } + + @Test + public void testDifferentInstancesSeparate() throws Exception { + PerSerializationStatic instance = + new PerSerializationStatic<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance); + AtomicInteger i = instance.get(); + i.set(10); + assertSame(i, instance.get()); + + PerSerializationStatic instance2 = + new PerSerializationStatic<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance2); + AtomicInteger j = instance2.get(); + j.set(20); + assertSame(j, instance2.get()); + assertNotSame(j, i); + + PerSerializationStatic instance1clone = SerializableUtils.clone(instance); + assertSame(instance1clone.get(), i); + PerSerializationStatic instance2clone = SerializableUtils.clone(instance2); + assertSame(instance2clone.get(), j); + } + + @SuppressWarnings("unchecked") + @Test + public void testDifferentInstancesSeparateNoGetBeforeSerialization() throws Exception { + PerSerializationStatic instance = + new PerSerializationStatic<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance); + + PerSerializationStatic instance2 = + new PerSerializationStatic<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance2); + + byte[] serialized = SerializableUtils.serializeToByteArray(instance); + PerSerializationStatic deserialized1 = + (PerSerializationStatic) + SerializableUtils.deserializeFromByteArray(serialized, "instance"); + PerSerializationStatic deserialized2 = + (PerSerializationStatic) + SerializableUtils.deserializeFromByteArray(serialized, "instance"); + assertSame(deserialized1.get(), deserialized2.get()); + + PerSerializationStatic instance2clone = SerializableUtils.clone(instance2); + assertNotSame(instance2clone.get(), deserialized1.get()); + } + + @Test + public void testDifferentTypes() throws Exception { + PerSerializationStatic instance = + new PerSerializationStatic<>(AtomicInteger::new); + SerializableUtils.ensureSerializable(instance); + AtomicInteger i = instance.get(); + i.set(10); + assertSame(i, instance.get()); + + PerSerializationStatic> instance2 = + new PerSerializationStatic<>(ConcurrentHashMap::new); + SerializableUtils.ensureSerializable(instance2); + ConcurrentHashMap j = instance2.get(); + j.put(1, 100); + assertSame(j, instance2.get()); + + PerSerializationStatic instance1clone = SerializableUtils.clone(instance); + assertSame(instance1clone.get(), i); + PerSerializationStatic> instance2clone = + SerializableUtils.clone(instance2); + assertSame(instance2clone.get(), j); + } +} From 755170c8a7245c326199c6076ded1d50256761fc Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 13 Feb 2025 12:55:32 +0100 Subject: [PATCH 2/6] rename and remove size limit --- .../java/org/apache/beam/sdk/io/Read.java | 15 ++--- ...PerInstantiationSerializableSupplier.java} | 21 ++++--- ...nstantiationSerializableSupplierTest.java} | 62 ++++++++++--------- 3 files changed, 54 insertions(+), 44 deletions(-) rename sdks/java/core/src/main/java/org/apache/beam/sdk/util/{PerSerializationStatic.java => MemoizingPerInstantiationSerializableSupplier.java} (66%) rename sdks/java/core/src/test/java/org/apache/beam/sdk/util/{PerSerializationStaticTest.java => MemoizingPerInstantiationSerializableSupplierTest.java} (58%) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index c5a087366e53..264239e26611 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -51,12 +51,11 @@ import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress; -import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Progress; import org.apache.beam.sdk.transforms.splittabledofn.SplitResult; import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.MemoizingPerInstantiationSerializableSupplier; import org.apache.beam.sdk.util.NameUtils; -import org.apache.beam.sdk.util.PerSerializationStatic; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -482,18 +481,19 @@ static class UnboundedSourceAsSDFWrapperFn checkpointCoder; - private final PerSerializationStatic>> cachedReaders; + private final MemoizingPerInstantiationSerializableSupplier< + Cache>> + cachedReaders; private @Nullable Coder> restrictionCoder; @VisibleForTesting UnboundedSourceAsSDFWrapperFn(Coder checkpointCoder) { this.checkpointCoder = checkpointCoder; cachedReaders = - new PerSerializationStatic<>( + new MemoizingPerInstantiationSerializableSupplier<>( () -> CacheBuilder.newBuilder() .expireAfterWrite(1, TimeUnit.MINUTES) - .maximumSize(100) .removalListener( (RemovalListener>) removalNotification -> { @@ -845,10 +845,11 @@ private static class UnboundedSourceAsSDFRestrictionTracker< implements HasProgress { private final UnboundedSourceRestriction initialRestriction; private final PipelineOptions pipelineOptions; + private final Cache> cachedReaders; + private final Coder> restrictionCoder; + private UnboundedSource.@Nullable UnboundedReader currentReader; private boolean readerHasBeenStarted; - private Cache> cachedReaders; - private Coder> restrictionCoder; UnboundedSourceAsSDFRestrictionTracker( UnboundedSourceRestriction initialRestriction, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PerSerializationStatic.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java similarity index 66% rename from sdks/java/core/src/main/java/org/apache/beam/sdk/util/PerSerializationStatic.java rename to sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java index e53648047f1b..a6a09d93b818 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/PerSerializationStatic.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java @@ -20,20 +20,24 @@ import java.io.Serializable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import javax.annotation.Nullable; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.checkerframework.checker.nullness.qual.NonNull; /** - * An object that simplifies having a variable that behaves like a static object but which is scoped - * to deserialized instances. + * A supplier that memoizes within an instantiation across serialization/deserialization. * - *

In particular this can be useful for use within a DoFn class to maintain shared state across - * all instances of the DoFn that are the same step in the graph. This differs from a static - * variable which would be shared across all instances of the DoFn and a non-static variable which - * is per instance. + *

Specifically the wrapped supplier will be called once and the result memoized - per instance + * for instances created via new - once per group of instances deserialized from the same serialized + * instance. + * + *

A particular use for this is within a DoFn class to maintain shared state across all instances + * of the DoFn that correspond to same step in the graph but separate from other steps in the graph + * using the same DoFn. This differs from a static variable which would be shared across all + * instances of the DoFn and a non-static variable which is per instance. */ -public class PerSerializationStatic implements Serializable { +public class MemoizingPerInstantiationSerializableSupplier implements Serializable, Supplier { private static final AtomicInteger idGenerator = new AtomicInteger(); private final int id; @@ -41,11 +45,12 @@ public class PerSerializationStatic implements Serializable { private final SerializableSupplier<@NonNull T> supplier; private transient volatile @MonotonicNonNull T value; - public PerSerializationStatic(SerializableSupplier<@NonNull T> supplier) { + public MemoizingPerInstantiationSerializableSupplier(SerializableSupplier<@NonNull T> supplier) { id = idGenerator.incrementAndGet(); this.supplier = supplier; } + @Override @SuppressWarnings("unchecked") public T get() { @Nullable T result = value; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PerSerializationStaticTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplierTest.java similarity index 58% rename from sdks/java/core/src/test/java/org/apache/beam/sdk/util/PerSerializationStaticTest.java rename to sdks/java/core/src/test/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplierTest.java index c95dbf436a7d..216682276ecf 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/PerSerializationStaticTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplierTest.java @@ -28,13 +28,13 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public class PerSerializationStaticTest { +public class MemoizingPerInstantiationSerializableSupplierTest { @SuppressWarnings("unchecked") @Test public void testSharedAcrossDeserialize() throws Exception { - PerSerializationStatic instance = - new PerSerializationStatic<>(AtomicInteger::new); + MemoizingPerInstantiationSerializableSupplier instance = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); SerializableUtils.ensureSerializable(instance); AtomicInteger i = instance.get(); @@ -42,13 +42,13 @@ public void testSharedAcrossDeserialize() throws Exception { assertSame(i, instance.get()); byte[] serialized = SerializableUtils.serializeToByteArray(instance); - PerSerializationStatic deserialized1 = - (PerSerializationStatic) + MemoizingPerInstantiationSerializableSupplier deserialized1 = + (MemoizingPerInstantiationSerializableSupplier) SerializableUtils.deserializeFromByteArray(serialized, "instance"); assertSame(i, deserialized1.get()); - PerSerializationStatic deserialized2 = - (PerSerializationStatic) + MemoizingPerInstantiationSerializableSupplier deserialized2 = + (MemoizingPerInstantiationSerializableSupplier) SerializableUtils.deserializeFromByteArray(serialized, "instance"); assertSame(i, deserialized2.get()); assertEquals(10, i.get()); @@ -56,71 +56,75 @@ public void testSharedAcrossDeserialize() throws Exception { @Test public void testDifferentInstancesSeparate() throws Exception { - PerSerializationStatic instance = - new PerSerializationStatic<>(AtomicInteger::new); + MemoizingPerInstantiationSerializableSupplier instance = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); SerializableUtils.ensureSerializable(instance); AtomicInteger i = instance.get(); i.set(10); assertSame(i, instance.get()); - PerSerializationStatic instance2 = - new PerSerializationStatic<>(AtomicInteger::new); + MemoizingPerInstantiationSerializableSupplier instance2 = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); SerializableUtils.ensureSerializable(instance2); AtomicInteger j = instance2.get(); j.set(20); assertSame(j, instance2.get()); assertNotSame(j, i); - PerSerializationStatic instance1clone = SerializableUtils.clone(instance); + MemoizingPerInstantiationSerializableSupplier instance1clone = + SerializableUtils.clone(instance); assertSame(instance1clone.get(), i); - PerSerializationStatic instance2clone = SerializableUtils.clone(instance2); + MemoizingPerInstantiationSerializableSupplier instance2clone = + SerializableUtils.clone(instance2); assertSame(instance2clone.get(), j); } @SuppressWarnings("unchecked") @Test public void testDifferentInstancesSeparateNoGetBeforeSerialization() throws Exception { - PerSerializationStatic instance = - new PerSerializationStatic<>(AtomicInteger::new); + MemoizingPerInstantiationSerializableSupplier instance = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); SerializableUtils.ensureSerializable(instance); - PerSerializationStatic instance2 = - new PerSerializationStatic<>(AtomicInteger::new); + MemoizingPerInstantiationSerializableSupplier instance2 = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); SerializableUtils.ensureSerializable(instance2); byte[] serialized = SerializableUtils.serializeToByteArray(instance); - PerSerializationStatic deserialized1 = - (PerSerializationStatic) + MemoizingPerInstantiationSerializableSupplier deserialized1 = + (MemoizingPerInstantiationSerializableSupplier) SerializableUtils.deserializeFromByteArray(serialized, "instance"); - PerSerializationStatic deserialized2 = - (PerSerializationStatic) + MemoizingPerInstantiationSerializableSupplier deserialized2 = + (MemoizingPerInstantiationSerializableSupplier) SerializableUtils.deserializeFromByteArray(serialized, "instance"); assertSame(deserialized1.get(), deserialized2.get()); - PerSerializationStatic instance2clone = SerializableUtils.clone(instance2); + MemoizingPerInstantiationSerializableSupplier instance2clone = + SerializableUtils.clone(instance2); assertNotSame(instance2clone.get(), deserialized1.get()); } @Test public void testDifferentTypes() throws Exception { - PerSerializationStatic instance = - new PerSerializationStatic<>(AtomicInteger::new); + MemoizingPerInstantiationSerializableSupplier instance = + new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new); SerializableUtils.ensureSerializable(instance); AtomicInteger i = instance.get(); i.set(10); assertSame(i, instance.get()); - PerSerializationStatic> instance2 = - new PerSerializationStatic<>(ConcurrentHashMap::new); + MemoizingPerInstantiationSerializableSupplier> instance2 = + new MemoizingPerInstantiationSerializableSupplier<>(ConcurrentHashMap::new); SerializableUtils.ensureSerializable(instance2); ConcurrentHashMap j = instance2.get(); j.put(1, 100); assertSame(j, instance2.get()); - PerSerializationStatic instance1clone = SerializableUtils.clone(instance); + MemoizingPerInstantiationSerializableSupplier instance1clone = + SerializableUtils.clone(instance); assertSame(instance1clone.get(), i); - PerSerializationStatic> instance2clone = - SerializableUtils.clone(instance2); + MemoizingPerInstantiationSerializableSupplier> + instance2clone = SerializableUtils.clone(instance2); assertSame(instance2clone.get(), j); } } From 43c67ccb0e6ad1154fc03b4b354e864d43e2e43f Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 13 Feb 2025 13:00:02 +0100 Subject: [PATCH 3/6] rename member, use specific interface --- .../core/src/main/java/org/apache/beam/sdk/io/Read.java | 6 +++--- .../util/MemoizingPerInstantiationSerializableSupplier.java | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index 264239e26611..ad14a97a46da 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -483,13 +483,13 @@ static class UnboundedSourceAsSDFWrapperFn checkpointCoder; private final MemoizingPerInstantiationSerializableSupplier< Cache>> - cachedReaders; + readerCacheSupplier; private @Nullable Coder> restrictionCoder; @VisibleForTesting UnboundedSourceAsSDFWrapperFn(Coder checkpointCoder) { this.checkpointCoder = checkpointCoder; - cachedReaders = + readerCacheSupplier = new MemoizingPerInstantiationSerializableSupplier<>( () -> CacheBuilder.newBuilder() @@ -561,7 +561,7 @@ public void splitRestriction( Coder> restrictionCoder = checkStateNotNull(this.restrictionCoder); Cache> cachedReaders = - checkStateNotNull(this.cachedReaders.get()); + checkStateNotNull(this.readerCacheSupplier.get()); return new UnboundedSourceAsSDFRestrictionTracker<>( restriction, pipelineOptions, cachedReaders, restrictionCoder); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java index a6a09d93b818..94297e073acb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java @@ -17,10 +17,8 @@ */ package org.apache.beam.sdk.util; -import java.io.Serializable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Supplier; import javax.annotation.Nullable; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.checkerframework.checker.nullness.qual.NonNull; @@ -37,7 +35,7 @@ * using the same DoFn. This differs from a static variable which would be shared across all * instances of the DoFn and a non-static variable which is per instance. */ -public class MemoizingPerInstantiationSerializableSupplier implements Serializable, Supplier { +public class MemoizingPerInstantiationSerializableSupplier implements SerializableSupplier { private static final AtomicInteger idGenerator = new AtomicInteger(); private final int id; From cd927ac53b64051d312dfbbb866575d384c2e703 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 13 Feb 2025 13:02:21 +0100 Subject: [PATCH 4/6] reword comment --- .../util/MemoizingPerInstantiationSerializableSupplier.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java index 94297e073acb..6cdb02ae701c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java @@ -26,9 +26,8 @@ /** * A supplier that memoizes within an instantiation across serialization/deserialization. * - *

Specifically the wrapped supplier will be called once and the result memoized - per instance - * for instances created via new - once per group of instances deserialized from the same serialized - * instance. + *

Specifically the wrapped supplier will be called once and the result memoized per group consisting + * of an instance and all instances deserialized from its serialized state. * *

A particular use for this is within a DoFn class to maintain shared state across all instances * of the DoFn that correspond to same step in the graph but separate from other steps in the graph From 9d19ec286f0d1ecdfbbee2815de921d6af955457 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Wed, 19 Feb 2025 12:57:02 +0100 Subject: [PATCH 5/6] Use remove to avoid possible shared unbounded reader. Use executor to close as there were previous problems closing inline with dataflow reader cache. --- .../java/org/apache/beam/sdk/io/Read.java | 29 ++++++++++++------- ...gPerInstantiationSerializableSupplier.java | 4 +-- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index ad14a97a46da..6da5c6cfcb86 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -28,6 +28,8 @@ import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; @@ -68,7 +70,9 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalCause; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListener; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.checkerframework.checker.nullness.qual.EnsuresNonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.common.value.qual.ArrayLen; @@ -484,6 +488,9 @@ static class UnboundedSourceAsSDFWrapperFn>> readerCacheSupplier; + private static final Executor closeExecutor = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat("UnboundedReaderCloses-%d").build()); private @Nullable Coder> restrictionCoder; @VisibleForTesting @@ -497,13 +504,16 @@ static class UnboundedSourceAsSDFWrapperFn>) removalNotification -> { - if (removalNotification.wasEvicted()) { - try { - Preconditions.checkNotNull(removalNotification.getValue()) - .close(); - } catch (IOException e) { - LOG.warn("Failed to close UnboundedReader.", e); - } + if (removalNotification.getCause() != RemovalCause.EXPLICIT) { + closeExecutor.execute( + () -> { + try { + Preconditions.checkNotNull(removalNotification.getValue()) + .close(); + } catch (IOException e) { + LOG.warn("Failed to close UnboundedReader.", e); + } + }); } }) .build()); @@ -876,7 +886,8 @@ private void initializeCurrentReader() throws IOException { checkState(currentReader == null); Object cacheKey = createCacheKey(initialRestriction.getSource(), initialRestriction.getCheckpoint()); - UnboundedReader cachedReader = cachedReaders.getIfPresent(cacheKey); + // We remove the reader if cached so that it is not possibly claimed by multiple DoFns. + UnboundedReader cachedReader = cachedReaders.asMap().remove(cacheKey); if (cachedReader == null) { this.currentReader = @@ -885,9 +896,7 @@ private void initializeCurrentReader() throws IOException { .createReader(pipelineOptions, initialRestriction.getCheckpoint()); } else { // If the reader is from cache, then we know that the reader has been started. - // We also remove this cache entry to avoid eviction. readerHasBeenStarted = true; - cachedReaders.invalidate(cacheKey); this.currentReader = cachedReader; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java index 6cdb02ae701c..b7b62cd24274 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java @@ -26,8 +26,8 @@ /** * A supplier that memoizes within an instantiation across serialization/deserialization. * - *

Specifically the wrapped supplier will be called once and the result memoized per group consisting - * of an instance and all instances deserialized from its serialized state. + *

Specifically the wrapped supplier will be called once and the result memoized per group + * consisting of an instance and all instances deserialized from its serialized state. * *

A particular use for this is within a DoFn class to maintain shared state across all instances * of the DoFn that correspond to same step in the graph but separate from other steps in the graph From 8126dd193d4ab3584351cfcc599890fdb18a6123 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 20 Feb 2025 10:50:00 +0000 Subject: [PATCH 6/6] address comment --- sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index 6da5c6cfcb86..01af3099a78d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -67,7 +67,6 @@ import org.apache.beam.sdk.values.ValueWithRecordId.StripIdsDoFn; import org.apache.beam.sdk.values.ValueWithRecordId.ValueWithRecordIdCoder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalCause; @@ -508,8 +507,7 @@ static class UnboundedSourceAsSDFWrapperFn { try { - Preconditions.checkNotNull(removalNotification.getValue()) - .close(); + checkStateNotNull(removalNotification.getValue()).close(); } catch (IOException e) { LOG.warn("Failed to close UnboundedReader.", e); }