diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json index a9e688973762..64f9b2e34efa 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json @@ -7,5 +7,6 @@ "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test", "https://github.com/apache/beam/pull/34123": "noting that PR #34123 should run this test", "https://github.com/apache/beam/pull/34080": "noting that PR #34080 should run this test", - "https://github.com/apache/beam/pull/34155": "noting that PR #34155 should run this test" + "https://github.com/apache/beam/pull/34155": "noting that PR #34155 should run this test", + "https://github.com/apache/beam/pull/34560": "noting that PR #34560 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json index c747288cc25b..3a01c4921572 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json @@ -6,5 +6,6 @@ "https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test", "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test", "https://github.com/apache/beam/pull/34080": "noting that PR #34080 should run this test", - "https://github.com/apache/beam/pull/34155": "noting that PR #34155 should run this test" + "https://github.com/apache/beam/pull/34155": "noting that PR #34155 should run this test", + "https://github.com/apache/beam/pull/34560": "noting that PR #34560 should run this test" } diff --git a/CHANGES.md b/CHANGES.md index 8097d48d2935..c1f3471f364f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -81,6 +81,7 @@ ## Deprecations * X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). +* Added support for streaming side-inputs in the Spark Classic runner ([#18136](https://github.com/apache/beam/issues/18136)). * Beam ZetaSQL is deprecated and will be removed no earlier than Beam 2.68.0 ([#34423](https://github.com/apache/beam/issues/34423)). Users are recommended to switch to [Calcite SQL](https://beam.apache.org/documentation/dsls/sql/calcite/overview/) dialect. 
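Illustrative usage (not part of the patch): a minimal sketch of a pipeline that exercises the streaming side-input support this change adds to the Spark Classic runner. It assumes the runner is selected externally (e.g. --runner=SparkRunner --streaming=true); the class name, step names, and rates are hypothetical. The unbounded view is re-triggered so each fired pane refreshes the broadcast value, and it may be empty until the first pane fires.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.windowing.AfterPane;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.PCollectionView;
import org.joda.time.Duration;

public class StreamingSideInputSketch {
  public static void main(String[] args) {
    // Hypothetical invocation; assumes --runner=SparkRunner --streaming=true on the command line.
    PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create();
    Pipeline p = Pipeline.create(options);

    // Unbounded side input: re-triggered per element so each fired pane refreshes the view.
    final PCollectionView<Iterable<Long>> sideView =
        p.apply("SideSource", GenerateSequence.from(0).withRate(1, Duration.standardSeconds(1)))
            .apply(
                Window.<Long>configure()
                    .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
                    .discardingFiredPanes()
                    .withAllowedLateness(Duration.ZERO))
            .apply(View.asIterable());

    // Unbounded main input reads the current contents of the view for every element.
    p.apply("MainSource", GenerateSequence.from(0).withRate(1, Duration.standardSeconds(1)))
        .apply(
            "ReadSideInput",
            ParDo.of(
                    new DoFn<Long, Long>() {
                      @ProcessElement
                      public void process(ProcessContext c) {
                        // The view may be empty until the first pane of the side input fires.
                        for (Long v : c.sideInput(sideView)) {
                          // consume v ...
                        }
                        c.output(c.element());
                      }
                    })
                .withSideInputs(sideView));

    p.run().waitUntilFinish();
  }
}

The re-triggered, discarding windowing mirrors the approach used in the new StreamingTransformTranslatorTest below; until the first pane arrives, the runner serves the empty initial view registered by the streaming view evaluator.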
diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle index f33c84d0e14c..037d46a31ed3 100644 --- a/runners/spark/spark_runner.gradle +++ b/runners/spark/spark_runner.gradle @@ -350,7 +350,6 @@ def validatesRunnerStreaming = tasks.register("validatesRunnerStreaming", Test) filter { // UNBOUNDED View.CreatePCollectionView not supported excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$BundleInvariantsTests.testWatermarkUpdateMidBundle' - excludeTestsMatching 'org.apache.beam.sdk.transforms.ViewTest.testWindowedSideInputNotPresent' // TODO(https://github.com/apache/beam/issues/29973) excludeTestsMatching 'org.apache.beam.sdk.transforms.ReshuffleTest.testReshufflePreservesMetadata' // TODO(https://github.com/apache/beam/issues/31231 diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkTransformOverrides.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkTransformOverrides.java index 5bab8e58098e..f5071e24a209 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkTransformOverrides.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkTransformOverrides.java @@ -18,6 +18,8 @@ package org.apache.beam.runners.spark; import java.util.List; +import org.apache.beam.runners.spark.translation.streaming.CreateStreamingSparkView; +import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverride; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.construction.PTransformMatchers; @@ -25,6 +27,7 @@ import org.apache.beam.sdk.util.construction.SplittableParDo; import org.apache.beam.sdk.util.construction.SplittableParDoNaiveBounded; import org.apache.beam.sdk.util.construction.UnsupportedOverrideFactory; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; /** {@link PTransform} overrides for Spark runner. 
*/ @@ -50,6 +53,20 @@ public static List getDefaultOverrides(boolean streaming) { PTransformOverride.of( PTransformMatchers.urnEqualTo(PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN), new SplittableParDoNaiveBounded.OverrideFactory())); + } else { + builder.add( + PTransformOverride.of( + // For streaming pipelines, this override is applied only when the PTransform has the + // same URN + // as PTransformTranslation.CREATE_VIEW_TRANSFORM and at least one of its inputs is + // UNBOUNDED + PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN) + .and( + (AppliedPTransform appliedPTransform) -> + appliedPTransform.getInputs().values().stream() + .anyMatch( + e -> e.isBounded().equals(PCollection.IsBounded.UNBOUNDED))), + CreateStreamingSparkView.Factory.INSTANCE)); } return builder.build(); } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java index 2725a57f7dc0..c4324479a38c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/CoderHelpers.java @@ -26,6 +26,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; +import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -101,6 +102,23 @@ public static List toByteArrays(Iterable values, Coder coder) return res; } + /** + * Utility method for serializing a Iterator of values using the specified coder. + * + * @param values Values to serialize. + * @param coder Coder to serialize with. + * @param type of value that is serialized + * @return List of bytes representing serialized objects. + */ + public static List toByteArrays(Iterator values, Coder coder) { + List res = new ArrayList<>(); + while (values.hasNext()) { + final T value = values.next(); + res.add(toByteArray(value, coder)); + } + return res; + } + /** * Utility method for deserializing a byte array using the specified coder. * diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 6d00eeaeace9..99f2a3e4c360 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -71,6 +71,7 @@ public class EvaluationContext { new HashMap<>(); private final PipelineOptions options; private final SerializablePipelineOptions serializableOptions; + private boolean streamingSideInput = false; public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline, PipelineOptions options) { this.jsc = jsc; @@ -358,4 +359,31 @@ Iterable> getWindowedValues(PCollection pcollection) { public String storageLevel() { return serializableOptions.get().as(SparkPipelineOptions.class).getStorageLevel(); } + + /** + * Checks if any of the side inputs in the pipeline are streaming side inputs. + * + *
<p>
If at least one of the side inputs is a streaming side input, this method returns true. When + * streaming side inputs are present, the {@link + * org.apache.beam.runners.spark.util.CachedSideInputReader} will not be used. + * + * @return true if any of the side inputs in the pipeline are streaming side inputs, false + * otherwise + */ + public boolean isStreamingSideInput() { + return streamingSideInput; + } + + /** + * Marks that the pipeline contains at least one streaming side input. + * + *
<p>
When this method is called, it sets the streamingSideInput flag to true, indicating that the + * {@link org.apache.beam.runners.spark.util.CachedSideInputReader} should not be used for + * processing side inputs. + */ + public void useStreamingSideInput() { + if (!this.streamingSideInput) { + this.streamingSideInput = true; + } + } } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java index df36d24531a6..353cf9d0ab90 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java @@ -31,9 +31,8 @@ import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator; -import org.apache.beam.runners.spark.util.CachedSideInputReader; import org.apache.beam.runners.spark.util.SideInputBroadcast; -import org.apache.beam.runners.spark.util.SparkSideInputReader; +import org.apache.beam.runners.spark.util.SideInputReaderFactory; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; @@ -66,6 +65,7 @@ public class MultiDoFnFunction private final MetricsContainerStepMapAccumulator metricsAccum; private final String stepName; private final DoFn doFn; + private final boolean useStreamingSideInput; private transient boolean wasSetupCalled; private final SerializablePipelineOptions options; private final TupleTag mainOutputTag; @@ -106,7 +106,8 @@ public MultiDoFnFunction( boolean stateful, DoFnSchemaInformation doFnSchemaInformation, Map> sideInputMapping, - boolean useBoundedConcurrentOutput) { + boolean useBoundedConcurrentOutput, + boolean useStreamingSideInput) { this.metricsAccum = metricsAccum; this.stepName = stepName; this.doFn = SerializableUtils.clone(doFn); @@ -121,6 +122,7 @@ public MultiDoFnFunction( this.doFnSchemaInformation = doFnSchemaInformation; this.sideInputMapping = sideInputMapping; this.useBoundedConcurrentOutput = useBoundedConcurrentOutput; + this.useStreamingSideInput = useStreamingSideInput; } @Override @@ -178,7 +180,7 @@ public TimerInternals timerInternals() { DoFnRunners.simpleRunner( options.get(), doFn, - CachedSideInputReader.of(new SparkSideInputReader(sideInputs)), + SideInputReaderFactory.create(this.useStreamingSideInput, this.sideInputs), processor.getOutputManager(), mainOutputTag, additionalOutputTags, diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SideInputMetadata.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SideInputMetadata.java new file mode 100644 index 000000000000..6c407806812c --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SideInputMetadata.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation; + +import java.io.Serializable; +import org.apache.beam.runners.spark.util.SideInputBroadcast; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.util.WindowedValue; + +/** + * Metadata class for side inputs in Spark runner. Contains serialized data, type information and + * coder for side input processing. + */ +public class SideInputMetadata implements Serializable { + private final byte[] data; + private final SparkPCollectionView.Type type; + private final Coder>> coder; + + /** + * Constructor for SideInputMetadata. + * + * @param data The serialized side input data as byte array + * @param type The type of the SparkPCollectionView + * @param coder The coder for iterables of windowed values + */ + SideInputMetadata( + byte[] data, SparkPCollectionView.Type type, Coder>> coder) { + this.data = data; + this.type = type; + this.coder = coder; + } + + /** + * Creates a new instance of SideInputMetadata. + * + * @param data The serialized side input data as byte array + * @param type The type of the SparkPCollectionView + * @param coder The coder for iterables of windowed values + * @return A new SideInputMetadata instance + */ + public static SideInputMetadata create( + byte[] data, SparkPCollectionView.Type type, Coder>> coder) { + return new SideInputMetadata(data, type, coder); + } + + /** + * Converts this metadata to a {@link SideInputBroadcast} instance. + * + * @return A new {@link SideInputBroadcast} instance created from this metadata + */ + @SuppressWarnings("rawtypes") + public SideInputBroadcast toSideInputBroadcast() { + return SideInputBroadcast.create(this.data, this.type, this.coder); + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java index c2d21be52144..4cc25bd6ffa2 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java @@ -28,7 +28,6 @@ import org.apache.spark.api.java.JavaSparkContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import scala.Tuple2; /** SparkPCollectionView is used to pass serialized views to lambdas. */ @SuppressWarnings({ @@ -41,17 +40,39 @@ public class SparkPCollectionView implements Serializable { // Holds the view --> broadcast mapping. Transient so it will be null from resume private transient volatile Map, SideInputBroadcast> broadcastHelperMap = null; + /** Type of side input. */ + public enum Type { + /** for fixed inputs. */ + STATIC, + /** for dynamically updated inputs. 
*/ + STREAMING + } + // Holds the Actual data of the views in serialize form - private final Map, Tuple2>>>> pviews = - new LinkedHashMap<>(); + private final Map, SideInputMetadata> pviews = new LinkedHashMap<>(); - // Driver only - during evaluation stage - void putPView( + public void putPView( PCollectionView view, Iterable> value, Coder>> coder) { + this.putPView(view, value, coder, Type.STATIC); + } + + public void putStreamingPView( + PCollectionView view, + Iterable> value, + Coder>> coder) { + this.putPView(view, value, coder, Type.STREAMING); + } + + // Driver only - during evaluation stage + private void putPView( + PCollectionView view, + Iterable> value, + Coder>> coder, + Type type) { - pviews.put(view, new Tuple2<>(CoderHelpers.toByteArray(value, coder), coder)); + pviews.put(view, SideInputMetadata.create(CoderHelpers.toByteArray(value, coder), type, coder)); // Currently unsynchronized unpersist, if needed can be changed to blocking if (broadcastHelperMap != null) { @@ -90,8 +111,8 @@ SideInputBroadcast getPCollectionView(PCollectionView view, JavaSparkContext private SideInputBroadcast createBroadcastHelper( PCollectionView view, JavaSparkContext context) { - Tuple2>>> tuple2 = pviews.get(view); - SideInputBroadcast helper = SideInputBroadcast.create(tuple2._1, tuple2._2); + final SideInputMetadata sideInputMetadata = pviews.get(view); + SideInputBroadcast helper = sideInputMetadata.toSideInputBroadcast(); String pCollectionName = view.getPCollection() != null ? view.getPCollection().getName() : "UNKNOWN"; LOG.debug( diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index f12da285c69d..1fea8b9329c6 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -432,6 +432,7 @@ public void evaluate( ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); TupleTag mainOutputTag = transform.getMainOutputTag(); + final boolean useStreamingSideInput = context.isStreamingSideInput(); MultiDoFnFunction multiDoFnFunction = new MultiDoFnFunction<>( metricsAccum, @@ -447,7 +448,8 @@ public void evaluate( stateful, doFnSchemaInformation, sideInputMapping, - useBoundedConcurrentOutput); + useBoundedConcurrentOutput, + useStreamingSideInput); if (stateful) { // Based on the fact that the signature is stateful, DoFnSignatures ensures diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamingSparkView.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamingSparkView.java new file mode 100644 index 000000000000..df5547655246 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamingSparkView.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation.streaming; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.Concatenate; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.sdk.util.construction.CreatePCollectionViewTranslation; +import org.apache.beam.sdk.util.construction.ReplacementOutputs; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PCollectionViews; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; + +/** Spark streaming overrides for various view (side input) transforms. */ +@SuppressWarnings("rawtypes") +public class CreateStreamingSparkView + extends PTransform, PCollection> { + private final PCollectionView view; + + public static final String CREATE_STREAMING_SPARK_VIEW_URN = + "beam:transform:spark:create-streaming-spark-view:v1"; + + public CreateStreamingSparkView(PCollectionView view) { + this.view = view; + } + + @Override + public PCollection expand(PCollection input) { + PCollection> iterable; + // See https://github.com/apache/beam/pull/25940 + if (view.getViewFn() instanceof PCollectionViews.IsSingletonView) { + final TypeDescriptor inputType = + Preconditions.checkStateNotNull(input.getTypeDescriptor()); + + iterable = + input + .apply(MapElements.into(TypeDescriptors.lists(inputType)).via(Lists::newArrayList)) + .setCoder(ListCoder.of(input.getCoder())); + } else { + iterable = input.apply(Combine.globally(new Concatenate()).withoutDefaults()); + } + + iterable.apply(CreateSparkPCollectionView.of(this.view)); + return input; + } + + /** + * Creates a primitive {@link PCollectionView}. + * + *
<p>
For internal use only by runner implementors. + * + * @param The type of the elements of the input PCollection + * @param The type associated with the {@link PCollectionView} used as a side input + */ + public static class CreateSparkPCollectionView + extends PTransform>, PCollection>> { + private PCollectionView view; + + private CreateSparkPCollectionView(PCollectionView view) { + this.view = view; + } + + public static CreateSparkPCollectionView of( + PCollectionView view) { + return new CreateSparkPCollectionView<>(view); + } + + @Override + public PCollection> expand(PCollection> input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded(), input.getCoder()); + } + + public PCollectionView getView() { + return this.view; + } + } + + public static class Factory + implements PTransformOverrideFactory< + PCollection, + PCollection, + PTransform, PCollection>> { + + public static final Factory INSTANCE = new Factory(); + + private Factory() {} + + @Override + public PTransformReplacement, PCollection> getReplacementTransform( + AppliedPTransform< + PCollection, + PCollection, + PTransform, PCollection>> + transform) { + final PCollection collection = + (PCollection) Iterables.getOnlyElement(transform.getInputs().values()); + + PCollectionView view; + try { + view = CreatePCollectionViewTranslation.getView(transform); + } catch (IOException e) { + throw new RuntimeException(e); + } + + final CreateStreamingSparkView createSparkView = + new CreateStreamingSparkView<>(view); + return PTransformReplacement.of(collection, createSparkView); + } + + @Override + public Map, ReplacementOutput> mapOutputs( + Map, PCollection> outputs, PCollection newOutput) { + return ReplacementOutputs.singleton(outputs, newOutput); + } + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java index 82557c3b972b..b9c7e7d6d63d 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java @@ -38,10 +38,9 @@ import org.apache.beam.runners.spark.translation.SparkInputDataProcessor; import org.apache.beam.runners.spark.translation.SparkProcessContext; import org.apache.beam.runners.spark.util.ByteArray; -import org.apache.beam.runners.spark.util.CachedSideInputReader; import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; import org.apache.beam.runners.spark.util.SideInputBroadcast; -import org.apache.beam.runners.spark.util.SparkSideInputReader; +import org.apache.beam.runners.spark.util.SideInputReaderFactory; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; @@ -109,6 +108,8 @@ public class ParDoStateUpdateFn, O private final Map watermarks; private final List sourceIds; private final TimerInternals.TimerDataCoderV2 timerDataCoder; + // for sideInput + private final boolean useStreamingSideInput; public ParDoStateUpdateFn( MetricsContainerStepMapAccumulator metricsAccum, @@ -126,7 +127,8 @@ public ParDoStateUpdateFn( DoFnSchemaInformation doFnSchemaInformation, Map> sideInputMapping, Map watermarks, - List sourceIds) { + List sourceIds, + boolean useStreamingSideInput) { this.metricsAccum = metricsAccum; 
this.stepName = stepName; this.doFn = SerializableUtils.clone(doFn); @@ -145,6 +147,7 @@ public ParDoStateUpdateFn( this.sourceIds = sourceIds; this.timerDataCoder = TimerInternals.TimerDataCoderV2.of(windowingStrategy.getWindowFn().windowCoder()); + this.useStreamingSideInput = useStreamingSideInput; } @Override @@ -199,7 +202,7 @@ public TimerInternals timerInternals() { DoFnRunners.simpleRunner( options.get(), doFn, - CachedSideInputReader.of(new SparkSideInputReader(sideInputs)), + SideInputReaderFactory.create(this.useStreamingSideInput, this.sideInputs), processor.getOutputManager(), (TupleTag) mainOutputTag, additionalOutputTags, diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java index 23bcfcb129ce..c135d6edbf8a 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java @@ -184,7 +184,8 @@ public void evaluate( doFnSchemaInformation, sideInputMapping, watermarks, - sourceIds))); + sourceIds, + context.isStreamingSideInput()))); all = processedPairDStream.flatMapToPair( diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index e06ef79e483f..b505d9bb3d41 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Queue; @@ -53,7 +54,9 @@ import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; import org.apache.beam.runners.spark.util.SideInputBroadcast; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineWithContext; @@ -86,6 +89,8 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext$; @@ -386,6 +391,71 @@ public String toNativeString() { }; } + @SuppressWarnings({"rawtypes", "unchecked"}) + private static + TransformEvaluator> + streamingSideInput() { + return new TransformEvaluator< + CreateStreamingSparkView.CreateSparkPCollectionView>() { + @Override + public void evaluate( + CreateStreamingSparkView.CreateSparkPCollectionView transform, + EvaluationContext context) { + final PCollection> input = context.getInput(transform); + 
final UnboundedDataset dataset = + (UnboundedDataset) context.borrowDataset(input); + final PCollectionView output = transform.getView(); + + final JavaDStream> dStream = dataset.getDStream(); + + Coder> coderInternal = + (Coder) + WindowedValue.getFullCoder( + ListCoder.of(output.getCoderInternal()), + output.getWindowingStrategyInternal().getWindowFn().windowCoder()); + + // Convert JavaDStream to byte array + // The (JavaDStream) cast is used to prevent CheckerFramework type checking errors + // CheckerFramework treats mismatched generic type parameters as errors, + // but at runtime this is safe due to type erasure + final JavaDStream byteConverted = + (JavaDStream) + dStream.mapPartitions( + (Iterator> iter) -> + CoderHelpers.toByteArrays(iter, (Coder) coderInternal).iterator()); + + // Update side input values whenever a new RDD arrives + final SparkPCollectionView pViews = context.getPViews(); + byteConverted.foreachRDD( + (JavaRDD rdd) -> { + final List collect = rdd.collect(); + final Iterable> iterable = + CoderHelpers.fromByteArrays(collect, (Coder) coderInternal); + + if (!Iterables.isEmpty(iterable)) { + pViews.putStreamingPView( + output, (Iterable) iterable, IterableCoder.of(coderInternal)); + } + }); + + // Enable streaming side input mode + context.useStreamingSideInput(); + + // Initialize with empty side input values + // In streaming environment, data from DStream is not immediately available + // The system initializes with empty values and updates them when data arrives + // This means side inputs may initially be null + context.putPView( + output, /*Empty Side Inputs*/ Lists.newArrayList(), IterableCoder.of(coderInternal)); + } + + @Override + public String toNativeString() { + return "streamingView()"; + } + }; + } + private static TransformEvaluator> combineGrouped() { return new TransformEvaluator>() { @@ -480,6 +550,8 @@ public void evaluate( final Map> sideInputMapping = ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + final boolean useStreamingSideInput = context.isStreamingSideInput(); + final String stepName = context.getCurrentTransform().getFullName(); JavaPairDStream, WindowedValue> all = dStream.transformToPair( @@ -508,7 +580,8 @@ public void evaluate( false, doFnSchemaInformation, sideInputMapping, - false)); + false, + useStreamingSideInput)); }); Map, PCollection> outputs = context.getOutputs(transform); @@ -589,6 +662,7 @@ public String toNativeString() { EVALUATORS.put(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN, window()); EVALUATORS.put(PTransformTranslation.FLATTEN_TRANSFORM_URN, flattenPColl()); EVALUATORS.put(PTransformTranslation.RESHUFFLE_URN, reshuffle()); + EVALUATORS.put(CreateStreamingSparkView.CREATE_STREAMING_SPARK_VIEW_URN, streamingSideInput()); // For testing only EVALUATORS.put(CreateStream.TRANSFORM_URN, createFromQueue()); EVALUATORS.put(PTransformTranslation.TEST_STREAM_TRANSFORM_URN, createFromTestStream()); @@ -596,6 +670,9 @@ public String toNativeString() { private static @Nullable TransformEvaluator getTranslator(PTransform transform) { @Nullable String urn = PTransformTranslation.urnForTransformOrNull(transform); + if (transform instanceof CreateStreamingSparkView.CreateSparkPCollectionView) { + urn = CreateStreamingSparkView.CREATE_STREAMING_SPARK_VIEW_URN; + } return urn == null ? 
null : EVALUATORS.get(urn); } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputBroadcast.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputBroadcast.java index 57c7c6f81870..cf6815c44ec9 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputBroadcast.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputBroadcast.java @@ -20,6 +20,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.Serializable; +import org.apache.beam.runners.spark.translation.SparkPCollectionView; import org.apache.beam.sdk.coders.Coder; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.broadcast.Broadcast; @@ -41,14 +42,18 @@ public class SideInputBroadcast implements Serializable { private final Coder coder; private transient T value; private transient byte[] bytes = null; + private SparkPCollectionView.Type sparkPCollectionViewType; - private SideInputBroadcast(byte[] bytes, Coder coder) { + private SideInputBroadcast( + byte[] bytes, Coder coder, SparkPCollectionView.Type sparkPCollectionViewType) { this.bytes = bytes; this.coder = coder; + this.sparkPCollectionViewType = sparkPCollectionViewType; } - public static SideInputBroadcast create(byte[] bytes, Coder coder) { - return new SideInputBroadcast<>(bytes, coder); + public static SideInputBroadcast create( + byte[] bytes, SparkPCollectionView.Type type, Coder coder) { + return new SideInputBroadcast<>(bytes, coder, type); } public synchronized T getValue() { @@ -62,6 +67,10 @@ public void broadcast(JavaSparkContext jsc) { this.bcast = jsc.broadcast(bytes); } + public SparkPCollectionView.Type getSparkPCollectionViewType() { + return sparkPCollectionViewType; + } + public void unpersist() { this.bcast.unpersist(); } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputReaderFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputReaderFactory.java new file mode 100644 index 000000000000..97dfb0a9b51e --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SideInputReaderFactory.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.util; + +import java.util.Map; +import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; + +/** + * Utility class for creating and managing side input readers in the Spark runner. + * + *
<p>
This class provides factory methods to create appropriate {@link SideInputReader} + * implementations based on the execution mode (streaming or batch) to optimize side input access + * patterns. + */ +public class SideInputReaderFactory { + /** + * Creates and returns a {@link SideInputReader} based on the configuration. + * + *
<p>
If streaming side inputs are enabled, returns a direct {@link SparkSideInputReader}. + * Otherwise, returns a cached version of the side input reader using {@link + * CachedSideInputReader} for better performance in batch processing. + * + * @param useStreamingSideInput Whether to use streaming side inputs + * @param sideInputs A map of side inputs with their windowing strategies and broadcasts + * @return A {@link SideInputReader} instance appropriate for the current configuration + */ + public static SideInputReader create( + boolean useStreamingSideInput, + Map, KV, SideInputBroadcast>> sideInputs) { + return useStreamingSideInput + ? new SparkSideInputReader(sideInputs) + : CachedSideInputReader.of(new SparkSideInputReader(sideInputs)); + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SparkSideInputReader.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SparkSideInputReader.java index ce32be569ae4..4e24d7e50e6a 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SparkSideInputReader.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SparkSideInputReader.java @@ -19,11 +19,14 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.Stream; import java.util.stream.StreamSupport; import org.apache.beam.runners.core.InMemoryMultimapSideInputView; import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.runners.spark.translation.SparkPCollectionView; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.transforms.Materializations; @@ -66,9 +69,11 @@ public SparkSideInputReader( // --- match the appropriate sideInput window. // a tag will point to all matching sideInputs, that is all windows. // now that we've obtained the appropriate sideInputWindow, all that's left is to filter by it. + final SideInputBroadcast sideInputBroadcast = windowedBroadcastHelper.getValue(); Iterable> availableSideInputs = - (Iterable>) windowedBroadcastHelper.getValue().getValue(); - Iterable sideInputForWindow = + (Iterable>) sideInputBroadcast.getValue(); + + final Stream> stream = StreamSupport.stream(availableSideInputs.spliterator(), false) .filter( sideInputCandidate -> { @@ -76,11 +81,9 @@ public SparkSideInputReader( return false; } return Iterables.contains(sideInputCandidate.getWindows(), sideInputWindow); - }) - .collect(Collectors.toList()) - .stream() - .map(WindowedValue::getValue) - .collect(Collectors.toList()); + }); + final List sideInputForWindow = + this.getSideInputForWindow(sideInputBroadcast.getSparkPCollectionViewType(), stream); switch (view.getViewFn().getMaterialization().getUrn()) { case Materializations.ITERABLE_MATERIALIZATION_URN: @@ -103,6 +106,46 @@ public SparkSideInputReader( } } + /** + * Extracts side input values from windowed values based on the collection view type. + * + *
<p>
For {@link SparkPCollectionView.Type#STATIC} view types, simply extracts the value from each + * {@link WindowedValue}. + * + *
<p>
For {@link SparkPCollectionView.Type#STREAMING} view types, performs additional processing + * by flattening any List values, as streaming side inputs arrive as collections that need to be + * processed individually. + * + * @param sparkPCollectionViewType the type of PCollection view (STATIC or STREAMING) + * @param stream the stream of WindowedValues filtered for the current window + * @return a list of extracted side input values + */ + private List getSideInputForWindow( + SparkPCollectionView.Type sparkPCollectionViewType, Stream> stream) { + switch (sparkPCollectionViewType) { + case STATIC: + return stream.map(WindowedValue::getValue).collect(Collectors.toList()); + case STREAMING: + return stream + .flatMap( + (WindowedValue windowedValue) -> { + final Object value = windowedValue.getValue(); + // Streaming side inputs arrive as List collections. + // These lists need to be flattened to process each element individually. + if (value instanceof List) { + final List list = (List) value; + return list.stream(); + } else { + return Stream.of(value); + } + }) + .collect(Collectors.toList()); + default: + throw new IllegalStateException( + String.format("Unknown pcollection view type %s", sparkPCollectionViewType)); + } + } + @Override public boolean contains(PCollectionView view) { return sideInputs.containsKey(view.getTagInternal()); diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslatorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslatorTest.java index e61f530748c7..309ad8a8ace3 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslatorTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslatorTest.java @@ -21,9 +21,12 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.is; import java.io.IOException; import java.io.Serializable; +import java.util.List; +import java.util.function.Function; import org.apache.beam.runners.spark.StreamingTest; import org.apache.beam.runners.spark.TestSparkPipelineOptions; import org.apache.beam.runners.spark.TestSparkRunner; @@ -33,19 +36,33 @@ import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.metrics.Distribution; import org.apache.beam.sdk.metrics.DistributionResult; import org.apache.beam.sdk.metrics.MetricNameFilter; +import org.apache.beam.sdk.metrics.MetricResult; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.UsesSideInputs; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.WithTimestamps; +import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.transforms.windowing.AfterPane; +import 
org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.After; @@ -58,8 +75,15 @@ /** Test suite for {@link StreamingTransformTranslator}. */ public class StreamingTransformTranslatorTest implements Serializable { + /** + * A functional interface that creates a {@link Pipeline} from {@link PipelineOptions}. Used in + * tests to define different pipeline configurations that can be executed with the same test + * harness. + */ + @FunctionalInterface + interface PipelineFunction extends Function {} + @Rule public transient TemporaryFolder temporaryFolder = new TemporaryFolder(); - public transient Pipeline p; /** Creates a temporary directory for storing checkpoints before each test execution. */ @Before @@ -71,6 +95,93 @@ public void init() { } } + private static class StreamingSideInputAsIterableView + extends PTransform>> { + private final Instant baseTimestamp; + + private StreamingSideInputAsIterableView(Instant baseTimestamp) { + this.baseTimestamp = baseTimestamp; + } + + @Override + public PCollectionView> expand(PBegin input) { + return input + .getPipeline() + .apply( + "Gen Seq", + GenerateSequence.from(0) + .withRate(1, Duration.millis(500)) + .withTimestampFn(e -> this.baseTimestamp.plus(Duration.millis(e * 100)))) + .apply( + Window.configure() + .withAllowedLateness(Duration.ZERO) + .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(2))) + .discardingFiredPanes()) + .setCoder(NullableCoder.of(VarLongCoder.of())) + .apply(View.asIterable()); + } + } + + @Test + @Category({StreamingTest.class, UsesSideInputs.class}) + public void testStreamingSideInputAsIterableView() { + final PipelineFunction pipelineFunction = + (PipelineOptions options) -> { + final Instant baseTimestamp = new Instant(0); + final Pipeline p = Pipeline.create(options); + + final PCollectionView> streamingSideInput = + p.apply( + "Streaming Side Input As Iterable View", + new StreamingSideInputAsIterableView(baseTimestamp)); + final PAssertFn pAssertFn = new PAssertFn(); + pAssertFn.streamingSideInputAsIterableView = streamingSideInput; + p.apply( + "Main Input", + GenerateSequence.from(0) + .withRate(1, Duration.millis(500)) + .withTimestampFn(e -> baseTimestamp.plus(Duration.millis(e * 100)))) + .apply( + "StreamingSideInputAssert", + ParDo.of(pAssertFn).withSideInput("streaming-side-input", streamingSideInput)); + + return p; + }; + + final PipelineResult result = run(pipelineFunction, Optional.of(new Instant(1000)), true); + final Iterable> distributions = + result + .metrics() + .queryMetrics( + MetricsFilter.builder() + .addNameFilter(MetricNameFilter.inNamespace(PAssertFn.class)) + .build()) + .getDistributions(); + + final MetricResult streamingIterSideInputMetricResult = + Iterables.find( + distributions, + dist -> dist.getName().getName().equals("streaming_iter_side_input_distribution")); + + final DistributionResult attempted = streamingIterSideInputMetricResult.getAttempted(); + + // The distribution metrics for the 
iterable side input are calculated based on only + // processing values [0, 1] to maintain consistent test behavior. + // Since we're only processing the pair [0, 1], the DistributionResult values will be: + // sum = count/2 (since we're summing the sequence [0, 1], [0, 1], ...) + // count = total number of elements processed + // min = 0 + // max = 1 + assertThat( + streamingIterSideInputMetricResult, + is( + attemptedMetricsResult( + PAssertFn.class.getName(), + "streaming_iter_side_input_distribution", + "StreamingSideInputAssert", + DistributionResult.create(attempted.getCount() / 2, attempted.getCount(), 0, 1)))); + } + /** * Tests that Flatten transform of Bounded and Unbounded PCollections correctly recovers from * checkpoint. @@ -100,7 +211,32 @@ public void testFlattenPCollResumeFromCheckpoint() { .addNameFilter(MetricNameFilter.inNamespace(PAssertFn.class)) .build(); - PipelineResult res = run(Optional.of(new Instant(400)), false); + final PipelineFunction pipelineFunction = + (PipelineOptions options) -> { + Pipeline p = Pipeline.create(options); + final Instant baseTimestamp = new Instant(0); + final PCollection bounded = + p.apply( + "Bounded", + GenerateSequence.from(0) + .to(10) + .withTimestampFn(e -> baseTimestamp.plus(Duration.millis(e * 100)))) + .apply("BoundedAssert", ParDo.of(new PAssertFn())); + + final PCollection unbounded = + p.apply( + "Unbounded", + GenerateSequence.from(10) + .withRate(3, Duration.standardSeconds(1)) + .withTimestampFn(e -> baseTimestamp.plus(Duration.millis(e * 100)))); + + final PCollection flattened = bounded.apply(Flatten.with(unbounded)); + + flattened.apply("FlattenedAssert", ParDo.of(new PAssertFn())); + return p; + }; + + PipelineResult res = run(pipelineFunction, Optional.of(new Instant(400)), false); // Verify metrics for Bounded PCollection (sum of 0-9 = 45, count = 10) assertThat( @@ -126,7 +262,7 @@ public void testFlattenPCollResumeFromCheckpoint() { clean(); // Second run: recover from checkpoint - res = runAgain(); + res = runAgain(pipelineFunction); // Verify Bounded PCollection metrics remain the same assertThat( @@ -164,8 +300,9 @@ public void testFlattenPCollResumeFromCheckpoint() { } /** Restarts the pipeline from checkpoint. Sets pipeline to stop after 1 second. */ - private PipelineResult runAgain() { + private PipelineResult runAgain(PipelineFunction pipelineFunction) { return run( + pipelineFunction, Optional.of( Instant.ofEpochMilli( Duration.standardSeconds(1L).plus(Duration.millis(50L)).getMillis())), @@ -175,34 +312,30 @@ private PipelineResult runAgain() { /** * Sets up and runs the test pipeline. 
* + * @param pipelineFunction Function that creates and configures the pipeline to be tested * @param stopWatermarkOption Watermark at which to stop the pipeline * @param deleteCheckpointDir Whether to delete checkpoint directory after completion */ - private PipelineResult run(Optional stopWatermarkOption, boolean deleteCheckpointDir) { - TestSparkPipelineOptions options = - PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class); - options.setSparkMaster("local[*]"); - options.setRunner(TestSparkRunner.class); + private PipelineResult run( + PipelineFunction pipelineFunction, + Optional stopWatermarkOption, + boolean deleteCheckpointDir) { + final TestSparkPipelineOptions options = this.createTestSparkPipelineOptions(); options.setCheckpointDir(temporaryFolder.getRoot().getPath()); if (stopWatermarkOption.isPresent()) { options.setStopPipelineWatermark(stopWatermarkOption.get().getMillis()); } options.setDeleteCheckpointDir(deleteCheckpointDir); - p = Pipeline.create(options); - - final PCollection bounded = - p.apply("Bounded", GenerateSequence.from(0).to(10)) - .apply("BoundedAssert", ParDo.of(new PAssertFn())); - - final PCollection unbounded = - p.apply("Unbounded", GenerateSequence.from(10).withRate(3, Duration.standardSeconds(1))) - .apply(WithTimestamps.of(e -> Instant.now())); - - final PCollection flattened = bounded.apply(Flatten.with(unbounded)); + return pipelineFunction.apply(options).run(); + } - flattened.apply("FlattenedAssert", ParDo.of(new PAssertFn())); - return p.run(); + private TestSparkPipelineOptions createTestSparkPipelineOptions() { + TestSparkPipelineOptions options = + PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class); + options.setSparkMaster("local[*]"); + options.setRunner(TestSparkRunner.class); + return options; } /** @@ -221,10 +354,29 @@ public void clean() { * elements in both bounded and unbounded streams. */ private static class PAssertFn extends DoFn { + @Nullable PCollectionView> streamingSideInputAsIterableView; private final Distribution distribution = Metrics.distribution(PAssertFn.class, "distribution"); + private final Distribution streamingIterSideInputDistribution = + Metrics.distribution(PAssertFn.class, "streaming_iter_side_input_distribution"); @ProcessElement - public void process(@Element Long element, OutputReceiver output) { + public void process( + ProcessContext context, @Element Long element, OutputReceiver output) { + if (this.streamingSideInputAsIterableView != null) { + final Iterable streamingSideInputIterValue = + context.sideInput(this.streamingSideInputAsIterableView); + final List sideInputValues = Lists.newArrayList(streamingSideInputIterValue); + // We only process side input values when they exactly match [0L, 1L] to ensure consistent + // test behavior across different runtime environments. The number of emitted elements can + // vary between test runs, so we need to filter for a specific pattern to maintain test + // determinism. + if (sideInputValues.equals(Lists.newArrayList(0L, 1L))) { + for (Long sideInputValue : sideInputValues) { + this.streamingIterSideInputDistribution.update(sideInputValue); + } + } + } + // For the unbounded source (starting from 10), we expect only 3 elements (10, 11, 12) // to be emitted during the 1-second test window. // However, different execution environments might emit more elements than expected