diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json
index f21bd16f4700..ff31524ec3ae 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json
@@ -5,5 +5,6 @@
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test",
"https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test",
- "https://github.com/apache/beam/pull/34123": "noting that PR #34123 should run this test"
+ "https://github.com/apache/beam/pull/34123": "noting that PR #34123 should run this test",
+ "https://github.com/apache/beam/pull/34080": "noting that PR #34080 should run this test"
}
diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json
index 87448a1525a0..8f47cf80e792 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json
@@ -4,5 +4,6 @@
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test",
- "https://github.com/apache/beam/pull/34123": "noting that PR #34123 should run this test"
+ "https://github.com/apache/beam/pull/34123": "noting that PR #34123 should run this test",
+ "https://github.com/apache/beam/pull/34080": "noting that PR #34080 should run this test"
}
diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json
index dd2bf3aeb361..e3fd31aabce2 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json
@@ -4,5 +4,6 @@
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test",
- "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test"
+ "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test",
+ "https://github.com/apache/beam/pull/34080": "noting that PR #34080 should run this test"
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
index 4ad7dd120693..e9c358282132 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
@@ -48,6 +48,12 @@ public interface TestSparkPipelineOptions extends SparkPipelineOptions, TestPipe
void setStopPipelineWatermark(Long stopPipelineWatermark);
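+
+ /**
+ * When false, the checkpoint directory is kept after the pipeline finishes so that a later run
+ * (e.g. a checkpoint-recovery test) can restore from it; {@link TestSparkRunner} only deletes
+ * the directory when this is true.
+ */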
+ @Description("Whether to delete the checkpoint directory after the pipeline execution.")
+ @Default.Boolean(true)
+ boolean isDeleteCheckpointDir();
+
+ void setDeleteCheckpointDir(boolean deleteCheckpointDir);
+
/**
* A factory to provide the default watermark to stop a pipeline that reads from an unbounded
* source.
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index e46b6c5f5e0b..22e25e5272a2 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -106,8 +106,10 @@ public SparkPipelineResult run(Pipeline pipeline) {
isOneOf(PipelineResult.State.STOPPED, PipelineResult.State.DONE));
} finally {
try {
- // cleanup checkpoint dir.
- FileUtils.deleteDirectory(new File(testSparkOptions.getCheckpointDir()));
+ if (testSparkOptions.isDeleteCheckpointDir()) {
+ // cleanup checkpoint dir.
+ FileUtils.deleteDirectory(new File(testSparkOptions.getCheckpointDir()));
+ }
} catch (IOException e) {
throw new RuntimeException("Failed to clear checkpoint tmp dir.", e);
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SingleEmitInputDStream.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SingleEmitInputDStream.java
new file mode 100644
index 000000000000..3cde41d2812e
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SingleEmitInputDStream.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation;
+
+import org.apache.spark.api.java.JavaSparkContext$;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.streaming.StreamingContext;
+import org.apache.spark.streaming.Time;
+import org.apache.spark.streaming.dstream.ConstantInputDStream;
+import org.apache.spark.streaming.dstream.QueueInputDStream;
+import scala.Option;
+
+/**
+ * A specialized {@link ConstantInputDStream} that emits its RDD exactly once. Alternative to
+ * {@link QueueInputDStream} when checkpointing is required.
+ *
+ * <p>Features:
+ *
+ * <ul>
+ *   <li>Supports checkpointing
+ *   <li>Guarantees single emission of data
+ *   <li>Returns empty RDD after first emission
+ * </ul>
+ *
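+ * <p>A minimal usage sketch (assuming an existing {@code JavaStreamingContext jssc} and {@code
+ * JavaRDD<String> rdd}; the names are illustrative):
+ *
+ * <pre>{@code
+ * SingleEmitInputDStream<String> input = new SingleEmitInputDStream<>(jssc.ssc(), rdd.rdd());
+ * JavaDStream<String> stream =
+ *     JavaDStream.fromDStream(input, JavaSparkContext$.MODULE$.fakeClassTag());
+ * }</pre>
+ *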
+ * @param <T> The type of elements in the RDD
+ */
+public class SingleEmitInputDStream<T> extends ConstantInputDStream<T> {
+
+ private boolean emitted = false;
+
+ public SingleEmitInputDStream(StreamingContext ssc, RDD<T> rdd) {
+ super(ssc, rdd, JavaSparkContext$.MODULE$.fakeClassTag());
+ }
+
+ @Override
+ public Option<RDD<T>> compute(Time validTime) {
+ if (this.emitted) {
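+ // Already emitted once: return an empty RDD for this and every subsequent batch.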
+ return Option.apply(this.emptyRDD());
+ } else {
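+ // First batch: mark as emitted and let ConstantInputDStream return the wrapped RDD.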
+ this.emitted = true;
+ return super.compute(validTime);
+ }
+ }
+
+ private RDD<T> emptyRDD() {
+ return this.context().sparkContext().emptyRDD(JavaSparkContext$.MODULE$.fakeClassTag());
+ }
+}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkStreamingPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkStreamingPortablePipelineTranslator.java
index 523dcbad0823..505a91e03b53 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkStreamingPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkStreamingPortablePipelineTranslator.java
@@ -30,9 +30,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Queue;
import java.util.Set;
-import java.util.concurrent.LinkedBlockingQueue;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.spark.SparkPipelineOptions;
@@ -61,10 +59,11 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.api.java.JavaDStream;
-import org.apache.spark.streaming.api.java.JavaInputDStream;
+import org.apache.spark.streaming.dstream.ConstantInputDStream;
import scala.Tuple2;
import scala.collection.JavaConverters;
@@ -158,15 +157,17 @@ private static void translateImpulse(
.parallelize(CoderHelpers.toByteArrays(windowedValues, windowCoder))
.map(CoderHelpers.fromByteFunction(windowCoder));
- Queue<JavaRDD<WindowedValue<byte[]>>> rddQueue = new LinkedBlockingQueue<>();
- rddQueue.offer(emptyByteArrayRDD);
- JavaInputDStream<WindowedValue<byte[]>> emptyByteArrayStream =
- context.getStreamingContext().queueStream(rddQueue, true /* oneAtATime */);
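+ // Emit the impulse via an input DStream that supports checkpointing; queue-based input
+ // streams are not recoverable from a checkpoint (https://github.com/apache/beam/issues/20426).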
+ final ConstantInputDStream<WindowedValue<byte[]>> inputDStream =
+ new ConstantInputDStream<>(
+ context.getStreamingContext().ssc(),
+ emptyByteArrayRDD.rdd(),
+ JavaSparkContext$.MODULE$.fakeClassTag());
+
+ final JavaDStream<WindowedValue<byte[]>> stream =
+ JavaDStream.fromDStream(inputDStream, JavaSparkContext$.MODULE$.fakeClassTag());
UnboundedDataset<byte[]> output =
- new UnboundedDataset<>(
- emptyByteArrayStream,
- Collections.singletonList(emptyByteArrayStream.inputDStream().id()));
+ new UnboundedDataset<>(stream, Collections.singletonList(inputDStream.id()));
// Add watermark to holder and advance to infinity to ensure future watermarks can be updated
GlobalWatermarkHolder.SparkWatermarks sparkWatermark =
@@ -175,7 +176,6 @@ private static void translateImpulse(
BoundedWindow.TIMESTAMP_MAX_VALUE,
context.getFirstTimestamp());
GlobalWatermarkHolder.add(output.getStreamSources().get(0), sparkWatermark);
-
context.pushDataset(getOutputId(transformNode), output);
}
@@ -297,6 +297,7 @@ public void setName(String name) {
}
}
+ @SuppressWarnings("unchecked")
private static <T> void translateFlatten(
PTransformNode transformNode,
RunnerApi.Pipeline pipeline,
@@ -306,9 +307,11 @@ private static void translateFlatten(
List<Integer> streamSources = new ArrayList<>();
if (inputsMap.isEmpty()) {
- Queue<JavaRDD<WindowedValue<T>>> q = new LinkedBlockingQueue<>();
- q.offer(context.getSparkContext().emptyRDD());
- unifiedStreams = context.getStreamingContext().queueStream(q);
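+ // Use a SingleEmitInputDStream for the empty input so this stream supports checkpointing.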
+ final JavaRDD<WindowedValue<T>> emptyRDD = context.getSparkContext().emptyRDD();
+ final SingleEmitInputDStream<WindowedValue<T>> singleEmitInputDStream =
+ new SingleEmitInputDStream<>(context.getStreamingContext().ssc(), emptyRDD.rdd());
+ unifiedStreams =
+ JavaDStream.fromDStream(singleEmitInputDStream, JavaSparkContext$.MODULE$.fakeClassTag());
} else {
List<JavaDStream<WindowedValue<T>>> dStreams = new ArrayList<>();
for (String inputId : inputsMap.values()) {
@@ -319,11 +322,13 @@ private static void translateFlatten(
dStreams.add(unboundedDataset.getDStream());
} else {
// create a single RDD stream.
- Queue<JavaRDD<WindowedValue<T>>> q = new LinkedBlockingQueue<>();
- q.offer(((BoundedDataset) dataset).getRDD());
- // TODO (https://github.com/apache/beam/issues/20426): this is not recoverable from
- // checkpoint!
- JavaDStream<WindowedValue<T>> dStream = context.getStreamingContext().queueStream(q);
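+ // Emit the bounded RDD exactly once through a checkpointable SingleEmitInputDStream instead
+ // of a queueStream, which is not recoverable from a checkpoint.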
+ final SingleEmitInputDStream<WindowedValue<T>> singleEmitInputDStream =
+ new SingleEmitInputDStream<WindowedValue<T>>(
+ context.getStreamingContext().ssc(), ((BoundedDataset) dataset).getRDD().rdd());
+ final JavaDStream<WindowedValue<T>> dStream =
+ JavaDStream.fromDStream(
+ singleEmitInputDStream, JavaSparkContext$.MODULE$.fakeClassTag());
+
dStreams.add(dStream);
}
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 539f8ff3efe6..e06ef79e483f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -43,6 +43,7 @@
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
+import org.apache.beam.runners.spark.translation.SingleEmitInputDStream;
import org.apache.beam.runners.spark.translation.SparkAssignWindowFn;
import org.apache.beam.runners.spark.translation.SparkCombineFn;
import org.apache.beam.runners.spark.translation.SparkPCollectionView;
@@ -88,6 +89,7 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.JavaSparkContext$;
+import org.apache.spark.streaming.StreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaInputDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
@@ -291,12 +293,8 @@ public void evaluate(Flatten.PCollections transform, EvaluationContext contex
dStreams.add(unboundedDataset.getDStream());
} else {
// create a single RDD stream.
- Queue<JavaRDD<WindowedValue<T>>> q = new LinkedBlockingQueue<>();
- q.offer(((BoundedDataset) dataset).getRDD());
- // TODO (https://github.com/apache/beam/issues/20426): this is not recoverable from
- // checkpoint!
- JavaDStream<WindowedValue<T>> dStream = context.getStreamingContext().queueStream(q);
- dStreams.add(dStream);
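+ // Wrap the bounded RDD in a SingleEmitInputDStream (see buildDStream below) so it is emitted
+ // once and survives checkpoint recovery.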
+ dStreams.add(
+ this.buildDStream(context.getStreamingContext().ssc(), (BoundedDataset) dataset));
}
}
// start by unifying streams into a single stream.
@@ -305,6 +303,15 @@ public void evaluate(Flatten.PCollections transform, EvaluationContext contex
context.putDataset(transform, new UnboundedDataset<>(unifiedStreams, streamingSources));
}
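+
+ /**
+ * Wraps a bounded dataset's RDD in a {@link SingleEmitInputDStream} so that it is emitted
+ * exactly once and remains recoverable from a checkpoint.
+ */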
+ private JavaDStream<WindowedValue<T>> buildDStream(
+ final StreamingContext ssc, final BoundedDataset dataset) {
+
+ final SingleEmitInputDStream<WindowedValue<T>> singleEmitDStream =
+ new SingleEmitInputDStream<>(ssc, dataset.getRDD().rdd());
+
+ return JavaDStream.fromDStream(singleEmitDStream, JavaSparkContext$.MODULE$.fakeClassTag());
+ }
+
@Override
public String toNativeString() {
return "streamingContext.union(...)";
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SingleEmitInputDStreamTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SingleEmitInputDStreamTest.java
new file mode 100644
index 000000000000..b13ca1b46c95
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SingleEmitInputDStreamTest.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.hasItems;
+import static org.junit.Assert.assertNotNull;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.List;
+import org.apache.beam.runners.spark.SparkContextRule;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.JavaSparkContext$;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import scala.Option;
+
+/**
+ * Tests for the {@link SingleEmitInputDStream} class, which ensures that data is emitted exactly
+ * once in a Spark streaming context.
+ */
+public class SingleEmitInputDStreamTest implements Serializable {
+ @ClassRule public static SparkContextRule sparkContext = new SparkContextRule();
+ @Rule public transient TemporaryFolder temporaryFolder = new TemporaryFolder();
+ private final Duration checkpointDuration = new Duration(500L);
+
+ /** Creates a temporary directory for storing checkpoints before each test execution. */
+ @Before
+ public void init() {
+ try {
+ temporaryFolder.create();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ public void singleEmitInputDStreamShouldBeEmitOnlyOnce() {
+ // Initialize Spark contexts
+ JavaSparkContext jsc = sparkContext.getSparkContext();
+ JavaStreamingContext jssc = new JavaStreamingContext(jsc, checkpointDuration);
+
+ // Create test data and wrap it in SingleEmitInputDStream
+ final SingleEmitInputDStream<String> singleEmitInputDStream = singleEmitInputDSTream(jsc, jssc);
+ singleEmitInputDStream.checkpoint(checkpointDuration);
+
+ // First computation: should return the original data
+ Option<RDD<String>> rddOption = singleEmitInputDStream.compute(null);
+ RDD<String> rdd = rddOption.get();
+ JavaRDD<String> javaRDD = JavaRDD.fromRDD(rdd, JavaSparkContext$.MODULE$.fakeClassTag());
+ List<String> collect = javaRDD.collect();
+ assertNotNull(collect);
+ assertThat(collect, hasItems("foo", "bar"));
+
+ // Second computation: should return empty RDD since data was already emitted
+ rddOption = singleEmitInputDStream.compute(null);
+ rdd = rddOption.get();
+ javaRDD = JavaRDD.fromRDD(rdd, JavaSparkContext$.MODULE$.fakeClassTag());
+ collect = javaRDD.collect();
+ assertNotNull(collect);
+ assertThat(collect, empty());
+ }
+
+ private SingleEmitInputDStream<String> singleEmitInputDSTream(
+ JavaSparkContext jsc, JavaStreamingContext jssc) {
+ final JavaRDD<String> stringRDD = jsc.parallelize(Lists.newArrayList("foo", "bar"));
+ final SingleEmitInputDStream<String> singleEmitInputDStream =
+ new SingleEmitInputDStream<>(jssc.ssc(), stringRDD.rdd());
+ return singleEmitInputDStream;
+ }
+
+ @Test
+ public void singleEmitInputDStreamShouldBeEmptyAfterCheckpointRecovery()
+ throws InterruptedException {
+ // Set up checkpoint directory
+ String checkpointPath = temporaryFolder.getRoot().getPath();
+
+ // Initialize Spark contexts with checkpoint configuration
+ JavaSparkContext jsc = sparkContext.getSparkContext();
+ jsc.setCheckpointDir(checkpointPath);
+ JavaStreamingContext jssc = new JavaStreamingContext(jsc, checkpointDuration);
+ jssc.checkpoint(checkpointPath);
+
+ // Create test data and configure SingleEmitInputDStream
+ final SingleEmitInputDStream<String> singleEmitInputDStream = singleEmitInputDSTream(jsc, jssc);
+ singleEmitInputDStream.checkpoint(checkpointDuration);
+
+ // Register output operation required by Spark Streaming
+ singleEmitInputDStream.print();
+
+ // Compute initial RDD and verify original data is present
+ Option<RDD<String>> rddOption = singleEmitInputDStream.compute(null);
+ RDD<String> rdd = rddOption.get();
+
+ JavaRDD<String> javaRDD = JavaRDD.fromRDD(rdd, JavaSparkContext$.MODULE$.fakeClassTag());
+ List<String> collect = javaRDD.collect();
+ assertNotNull(collect);
+ assertThat(collect, hasItems("foo", "bar"));
+
+ // Start streaming context to create checkpoint data
+ jssc.start();
+ // Wait for checkpoint to be created and written
+ jssc.awaitTerminationOrTimeout(1000);
+ // Ensure clean shutdown and checkpoint writing
+ jssc.stop(true, true);
+
+ // Recover streaming context from checkpoint
+ JavaStreamingContext recoveredJssc =
+ JavaStreamingContext.getOrCreate(
+ checkpointPath,
+ () -> {
+ throw new RuntimeException(
+ "Should not create new context, should recover from checkpoint");
+ });
+
+ try {
+ // Extract recovered DStream from the restored context
+ @SuppressWarnings("unchecked")
+ SingleEmitInputDStream<String> recoveredDStream =
+ (SingleEmitInputDStream<String>) recoveredJssc.ssc().graph().getInputStreams()[0];
+
+ // Compute RDD from recovered DStream and verify it's empty
+ Option<RDD<String>> recoveredRddOption = recoveredDStream.compute(null);
+ RDD<String> recoveredRdd = recoveredRddOption.get();
+ JavaRDD<String> recoveredJavaRdd =
+ JavaRDD.fromRDD(recoveredRdd, JavaSparkContext$.MODULE$.fakeClassTag());
+ List<String> recoveredCollect = recoveredJavaRdd.collect();
+
+ // Verify that recovered DStream produces empty results
+ assertNotNull(recoveredCollect);
+ assertThat(recoveredCollect, empty());
+
+ } finally {
+ // Ensure recovered context is properly cleaned up
+ recoveredJssc.stop(true, true);
+ }
+ }
+}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslatorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslatorTest.java
new file mode 100644
index 000000000000..e61f530748c7
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslatorTest.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation.streaming;
+
+import static org.apache.beam.sdk.metrics.MetricResultsMatchers.attemptedMetricsResult;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.hasItem;
+
+import java.io.IOException;
+import java.io.Serializable;
+import org.apache.beam.runners.spark.StreamingTest;
+import org.apache.beam.runners.spark.TestSparkPipelineOptions;
+import org.apache.beam.runners.spark.TestSparkRunner;
+import org.apache.beam.runners.spark.UsesCheckpointRecovery;
+import org.apache.beam.runners.spark.io.MicrobatchSource;
+import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
+import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.io.GenerateSequence;
+import org.apache.beam.sdk.metrics.Distribution;
+import org.apache.beam.sdk.metrics.DistributionResult;
+import org.apache.beam.sdk.metrics.MetricNameFilter;
+import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.metrics.MetricsFilter;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.WithTimestamps;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.TemporaryFolder;
+
+/** Test suite for {@link StreamingTransformTranslator}. */
+public class StreamingTransformTranslatorTest implements Serializable {
+
+ @Rule public transient TemporaryFolder temporaryFolder = new TemporaryFolder();
+ public transient Pipeline p;
+
+ /** Creates a temporary directory for storing checkpoints before each test execution. */
+ @Before
+ public void init() {
+ try {
+ temporaryFolder.create();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Tests that the Flatten transform of Bounded and Unbounded PCollections correctly recovers from
+ * checkpoint.
+ *
+ * <p>Test scenario:
+ *
+ * <ul>
+ *   <li>First run:
+ *       <ul>
+ *         <li>Flattens Bounded PCollection(0-9) with Unbounded PCollection
+ *         <li>Stops pipeline after 400ms
+ *         <li>Validates metrics results
+ *       </ul>
+ *   <li>Second run (recovery from checkpoint):
+ *       <ul>
+ *         <li>Recovers from previous state and continues execution
+ *         <li>Stops pipeline after 1 second
+ *         <li>Validates accumulated metrics results
+ *       </ul>
+ * </ul>
+ */
+ @Category({UsesCheckpointRecovery.class, StreamingTest.class})
+ @Test
+ public void testFlattenPCollResumeFromCheckpoint() {
+ final MetricsFilter metricsFilter =
+ MetricsFilter.builder()
+ .addNameFilter(MetricNameFilter.inNamespace(PAssertFn.class))
+ .build();
+
+ PipelineResult res = run(Optional.of(new Instant(400)), false);
+
+ // Verify metrics for Bounded PCollection (sum of 0-9 = 45, count = 10)
+ assertThat(
+ res.metrics().queryMetrics(metricsFilter).getDistributions(),
+ hasItem(
+ attemptedMetricsResult(
+ PAssertFn.class.getName(),
+ "distribution",
+ "BoundedAssert",
+ DistributionResult.create(45, 10, 0L, 9L))));
+
+ // Verify metrics for Flattened result after first run
+ assertThat(
+ res.metrics().queryMetrics(metricsFilter).getDistributions(),
+ hasItem(
+ attemptedMetricsResult(
+ PAssertFn.class.getName(),
+ "distribution",
+ "FlattenedAssert",
+ DistributionResult.create(45, 10, 0L, 9L))));
+
+ // Clean up state
+ clean();
+
+ // Second run: recover from checkpoint
+ res = runAgain();
+
+ // Verify Bounded PCollection metrics remain the same
+ assertThat(
+ res.metrics().queryMetrics(metricsFilter).getDistributions(),
+ hasItem(
+ attemptedMetricsResult(
+ PAssertFn.class.getName(),
+ "distribution",
+ "BoundedAssert",
+ DistributionResult.create(45, 10, 0L, 9L))));
+
+ // Verify Flattened results show accumulated values from both runs
+ // We use anyOf matcher because the unbounded source may emit either 2 or 3 elements during the
+ // test window:
+ // Case 1 (3 elements): sum=78 (45 from bounded + 33 from unbounded), count=13 (10 bounded + 3
+ // unbounded)
+ // Case 2 (2 elements): sum=66 (45 from bounded + 21 from unbounded), count=12 (10 bounded + 2
+ // unbounded)
+ // This variation occurs because the unbounded source's withRate(3, Duration.standardSeconds(1))
+ // timing may be affected by test environment conditions
+ assertThat(
+ res.metrics().queryMetrics(metricsFilter).getDistributions(),
+ hasItem(
+ anyOf(
+ attemptedMetricsResult(
+ PAssertFn.class.getName(),
+ "distribution",
+ "FlattenedAssert",
+ DistributionResult.create(78, 13, 0, 12)),
+ attemptedMetricsResult(
+ PAssertFn.class.getName(),
+ "distribution",
+ "FlattenedAssert",
+ DistributionResult.create(66, 12, 0, 11)))));
+ }
+
+ /** Restarts the pipeline from checkpoint. Sets pipeline to stop after 1 second. */
+ private PipelineResult runAgain() {
+ return run(
+ Optional.of(
+ Instant.ofEpochMilli(
+ Duration.standardSeconds(1L).plus(Duration.millis(50L)).getMillis())),
+ true);
+ }
+
+ /**
+ * Sets up and runs the test pipeline.
+ *
+ * @param stopWatermarkOption Watermark at which to stop the pipeline
+ * @param deleteCheckpointDir Whether to delete checkpoint directory after completion
+ */
+ private PipelineResult run(Optional<Instant> stopWatermarkOption, boolean deleteCheckpointDir) {
+ TestSparkPipelineOptions options =
+ PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
+ options.setSparkMaster("local[*]");
+ options.setRunner(TestSparkRunner.class);
+ options.setCheckpointDir(temporaryFolder.getRoot().getPath());
+ if (stopWatermarkOption.isPresent()) {
+ options.setStopPipelineWatermark(stopWatermarkOption.get().getMillis());
+ }
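+ // The first run keeps the checkpoint directory (deleteCheckpointDir=false) so that the
+ // second run can recover from it.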
+ options.setDeleteCheckpointDir(deleteCheckpointDir);
+
+ p = Pipeline.create(options);
+
+ final PCollection<Long> bounded =
+ p.apply("Bounded", GenerateSequence.from(0).to(10))
+ .apply("BoundedAssert", ParDo.of(new PAssertFn()));
+
+ final PCollection<Long> unbounded =
+ p.apply("Unbounded", GenerateSequence.from(10).withRate(3, Duration.standardSeconds(1)))
+ .apply(WithTimestamps.of(e -> Instant.now()));
+
+ final PCollection<Long> flattened = bounded.apply(Flatten.with(unbounded));
+
+ flattened.apply("FlattenedAssert", ParDo.of(new PAssertFn()));
+ return p.run();
+ }
+
+ /**
+ * Cleans up accumulated state between test runs. Clears metrics, watermarks, and microbatch
+ * source cache.
+ */
+ @After
+ public void clean() {
+ MetricsAccumulator.clear();
+ GlobalWatermarkHolder.clear();
+ MicrobatchSource.clearCache();
+ }
+
+ /**
+ * DoFn that tracks element distribution through metrics. Used to verify correct processing of
+ * elements in both bounded and unbounded streams.
+ */
+ private static class PAssertFn extends DoFn<Long, Long> {
+ private final Distribution distribution = Metrics.distribution(PAssertFn.class, "distribution");
+
+ @ProcessElement
+ public void process(@Element Long element, OutputReceiver<Long> output) {
+ // For the unbounded source (starting from 10), we expect only 3 elements (10, 11, 12)
+ // to be emitted during the 1-second test window.
+ // However, different execution environments might emit more elements than expected
+ // despite the withRate(3, Duration.standardSeconds(1)) setting.
+ // Therefore, we filter out elements >= 13 to ensure consistent test behavior
+ // across all environments.
+ if (element >= 13L) {
+ return;
+ }
+ distribution.update(element);
+ output.output(element);
+ }
+ }
+}