diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle
index 3344c42fe3f6..f5555854c24f 100644
--- a/runners/spark/spark_runner.gradle
+++ b/runners/spark/spark_runner.gradle
@@ -94,7 +94,6 @@ if (copySourceBase) {
}
test {
- systemProperty "beam.spark.test.reuseSparkContext", "true"
systemProperty "spark.sql.shuffle.partitions", "4"
systemProperty "spark.ui.enabled", "false"
systemProperty "spark.ui.showConsoleProgress", "false"
@@ -113,17 +112,14 @@ test {
jvmArgs System.getProperty("beamSurefireArgline")
}
- // Only one SparkContext may be running in a JVM (SPARK-2243)
- forkEvery 1
maxParallelForks 4
useJUnit {
excludeCategories "org.apache.beam.runners.spark.StreamingTest"
excludeCategories "org.apache.beam.runners.spark.UsesCheckpointRecovery"
}
- filter {
- // BEAM-11653 MetricsSinkTest is failing with Spark 3
- excludeTestsMatching 'org.apache.beam.runners.spark.aggregators.metrics.sink.SparkMetricsSinkTest'
- }
+
+ // easily re-run all tests (to deal with flaky tests / SparkContext leaks)
+ if(project.hasProperty("rerun-tests")) { outputs.upToDateWhen {false} }
}
dependencies {
@@ -289,10 +285,6 @@ def validatesRunnerStreaming = tasks.register("validatesRunnerStreaming", Test)
useJUnit {
includeCategories 'org.apache.beam.runners.spark.StreamingTest'
}
- filter {
- // BEAM-11653 MetricsSinkTest is failing with Spark 3
- excludeTestsMatching 'org.apache.beam.runners.spark.aggregators.metrics.sink.SparkMetricsSinkTest'
- }
}
tasks.register("validatesStructuredStreamingRunnerBatch", Test) {
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextOptions.java
index 13ae67878eb2..39caee7e6ba7 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkContextOptions.java
@@ -37,6 +37,12 @@
* which link to Spark dependencies, won't be scanned by {@link PipelineOptions} reflective
* instantiation. Note that {@link SparkContextOptions} is not registered with {@link
* SparkRunnerRegistrar}.
+ *
+ * <p>Note: It's recommended to use {@link
+ * org.apache.beam.runners.spark.translation.SparkContextFactory#setProvidedSparkContext(JavaSparkContext)}
+ * instead of {@link SparkContextOptions#setProvidedSparkContext(JavaSparkContext)} for testing.
+ * When using {@link org.apache.beam.sdk.testing.TestPipeline}, any provided {@link
+ * JavaSparkContext} via {@link SparkContextOptions} is dropped.
*/
public interface SparkContextOptions extends SparkPipelineOptions {
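For illustration (not part of the patch), a minimal sketch of the recommended test setup described in the note above: the context is registered with the factory so it survives option serialization, while the options only flag that a provided context should be used. The class name is illustrative; the factory and option methods are the ones referenced in this change.

    JavaSparkContext jsc = new JavaSparkContext("local[2]", "provided-for-tests");
    // Recommended for tests: a context set only on SparkContextOptions is dropped
    // when the options are serialized (e.g. by TestPipeline).
    SparkContextFactory.setProvidedSparkContext(jsc);

    SparkContextOptions options = PipelineOptionsFactory.as(SparkContextOptions.class);
    options.setUsesProvidedSparkContext(true);
    // ... run pipelines with these options, then clearProvidedSparkContext() and stop jsc.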
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java
index 65f83f5e8195..0474ca580a44 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java
@@ -21,6 +21,7 @@
import java.util.concurrent.TimeoutException;
import org.apache.beam.runners.core.construction.SplittableParDo;
import org.apache.beam.runners.spark.translation.EvaluationContext;
+import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.runners.spark.translation.streaming.StreamingTransformTranslator;
@@ -86,7 +87,8 @@ public SparkPipelineResult run(Pipeline pipeline) {
SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReadsIfNecessary(pipeline);
}
- JavaSparkContext jsc = new JavaSparkContext("local[1]", "Debug_Pipeline");
+ JavaSparkContext jsc =
+ SparkContextFactory.getSparkContext(pipeline.getOptions().as(SparkPipelineOptions.class));
JavaStreamingContext jssc =
new JavaStreamingContext(jsc, new org.apache.spark.streaming.Duration(1000));
@@ -107,7 +109,7 @@ public SparkPipelineResult run(Pipeline pipeline) {
pipeline.traverseTopologically(visitor);
- jsc.stop();
+ SparkContextFactory.stopSparkContext(jsc);
String debugString = visitor.getDebugString();
LOG.info("Translated Native Spark pipeline:\n" + debugString);
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
index 61cf3afed9ca..9f9465ccde8f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
@@ -17,6 +17,9 @@
*/
package org.apache.beam.runners.spark.translation;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
+
+import javax.annotation.Nullable;
import org.apache.beam.runners.spark.SparkContextOptions;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.SparkRunnerKryoRegistrator;
@@ -25,80 +28,121 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-/** The Spark context factory. */
-@SuppressWarnings({
- "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
public final class SparkContextFactory {
private static final Logger LOG = LoggerFactory.getLogger(SparkContextFactory.class);
/**
* If the property {@code beam.spark.test.reuseSparkContext} is set to {@code true} then the Spark
* context will be reused for beam pipelines. This property should only be enabled for tests.
+ *
+ * @deprecated This will leak your SparkContext, any attempt to create a new SparkContext later
+ * will fail. Please use {@link #setProvidedSparkContext(JavaSparkContext)} / {@link
+ * #clearProvidedSparkContext()} instead to properly control the lifecycle of your context.
+ * Alternatively you may also provide a SparkContext using {@link
+ * SparkContextOptions#setUsesProvidedSparkContext(boolean)} together with {@link
+ * SparkContextOptions#setProvidedSparkContext(JavaSparkContext)} and close that one
+ * appropriately. Tests of this module should use {@code SparkContextRule}.
*/
+ @Deprecated
public static final String TEST_REUSE_SPARK_CONTEXT = "beam.spark.test.reuseSparkContext";
// Spark allows only one context for JVM so this can be static.
- private static JavaSparkContext sparkContext;
- private static String sparkMaster;
- private static boolean usesProvidedSparkContext;
+ private static @Nullable JavaSparkContext sparkContext;
+
+ // Remember spark master if TEST_REUSE_SPARK_CONTEXT is enabled.
+ private static @Nullable String reusableSparkMaster;
+
+ // SparkContext is provided by the user instead of simply reused using TEST_REUSE_SPARK_CONTEXT
+ private static boolean hasProvidedSparkContext;
private SparkContextFactory() {}
+ /**
+ * Set an externally managed {@link JavaSparkContext} that will be used if {@link
+ * SparkContextOptions#getUsesProvidedSparkContext()} is set to {@code true}.
+ *
+ * <p>A Spark context can also be provided using {@link
+ * SparkContextOptions#setProvidedSparkContext(JavaSparkContext)}. However, it will be dropped
+ * during serialization potentially leading to confusing behavior. This is particularly the case
+ * when used in tests with {@link org.apache.beam.sdk.testing.TestPipeline}.
+ */
+ public static synchronized void setProvidedSparkContext(JavaSparkContext providedSparkContext) {
+ sparkContext = checkNotNull(providedSparkContext);
+ hasProvidedSparkContext = true;
+ reusableSparkMaster = null;
+ }
+
+ public static synchronized void clearProvidedSparkContext() {
+ hasProvidedSparkContext = false;
+ sparkContext = null;
+ }
+
public static synchronized JavaSparkContext getSparkContext(SparkPipelineOptions options) {
SparkContextOptions contextOptions = options.as(SparkContextOptions.class);
- usesProvidedSparkContext = contextOptions.getUsesProvidedSparkContext();
- // reuse should be ignored if the context is provided.
- if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !usesProvidedSparkContext) {
-
- // if the context is null or stopped for some reason, re-create it.
- if (sparkContext == null || sparkContext.sc().isStopped()) {
- sparkContext = createSparkContext(contextOptions);
- sparkMaster = options.getSparkMaster();
- } else if (!options.getSparkMaster().equals(sparkMaster)) {
- throw new IllegalArgumentException(
+ if (contextOptions.getUsesProvidedSparkContext()) {
+ JavaSparkContext jsc = contextOptions.getProvidedSparkContext();
+ if (jsc != null) {
+ setProvidedSparkContext(jsc);
+ } else if (hasProvidedSparkContext) {
+ jsc = sparkContext;
+ }
+ if (jsc == null) {
+ throw new IllegalStateException(
+ "No Spark context was provided. Use SparkContextFactor.setProvidedSparkContext to do so.");
+ } else if (jsc.sc().isStopped()) {
+ LOG.error("The provided Spark context " + jsc + " was already stopped.");
+ throw new IllegalStateException("The provided Spark context was already stopped");
+ }
+ LOG.info("Using a provided Spark Context");
+ return jsc;
+ } else if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT)) {
+ // This is highly discouraged as it leaks the SparkContext without any way to close it.
+ // Attempting to create any new SparkContext later will fail.
+ // If the context is null or stopped for some reason, re-create it.
+ @Nullable JavaSparkContext jsc = sparkContext;
+ if (jsc == null || jsc.sc().isStopped()) {
+ sparkContext = jsc = createSparkContext(contextOptions);
+ reusableSparkMaster = options.getSparkMaster();
+ hasProvidedSparkContext = false;
+ } else if (hasProvidedSparkContext) {
+ throw new IllegalStateException(
+ "Usage of provided Spark context is disabled in SparkPipelineOptions.");
+ } else if (!options.getSparkMaster().equals(reusableSparkMaster)) {
+ throw new IllegalStateException(
String.format(
"Cannot reuse spark context "
+ "with different spark master URL. Existing: %s, requested: %s.",
- sparkMaster, options.getSparkMaster()));
+ reusableSparkMaster, options.getSparkMaster()));
}
- return sparkContext;
+ return jsc;
} else {
- return createSparkContext(contextOptions);
+ JavaSparkContext jsc = createSparkContext(contextOptions);
+ clearProvidedSparkContext(); // any provided context can't be valid anymore
+ return jsc;
}
}
public static synchronized void stopSparkContext(JavaSparkContext context) {
- if (!Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !usesProvidedSparkContext) {
+ if (!Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !hasProvidedSparkContext) {
context.stop();
}
}
- private static JavaSparkContext createSparkContext(SparkContextOptions contextOptions) {
- if (usesProvidedSparkContext) {
- LOG.info("Using a provided Spark Context");
- JavaSparkContext jsc = contextOptions.getProvidedSparkContext();
- if (jsc == null || jsc.sc().isStopped()) {
- LOG.error("The provided Spark context " + jsc + " was not created or was stopped");
- throw new RuntimeException("The provided Spark context was not created or was stopped");
- }
- return jsc;
- } else {
- LOG.info("Creating a brand new Spark Context.");
- SparkConf conf = new SparkConf();
- if (!conf.contains("spark.master")) {
- // set master if not set.
- conf.setMaster(contextOptions.getSparkMaster());
- }
-
- if (contextOptions.getFilesToStage() != null && !contextOptions.getFilesToStage().isEmpty()) {
- conf.setJars(contextOptions.getFilesToStage().toArray(new String[0]));
- }
+ private static JavaSparkContext createSparkContext(SparkPipelineOptions options) {
+ LOG.info("Creating a brand new Spark Context.");
+ SparkConf conf = new SparkConf();
+ if (!conf.contains("spark.master")) {
+ // set master if not set.
+ conf.setMaster(options.getSparkMaster());
+ }
- conf.setAppName(contextOptions.getAppName());
- // register immutable collections serializers because the SDK uses them.
- conf.set("spark.kryo.registrator", SparkRunnerKryoRegistrator.class.getName());
- return new JavaSparkContext(conf);
+ if (options.getFilesToStage() != null && !options.getFilesToStage().isEmpty()) {
+ conf.setJars(options.getFilesToStage().toArray(new String[0]));
}
+
+ conf.setAppName(options.getAppName());
+ // register immutable collections serializers because the SDK uses them.
+ conf.set("spark.kryo.registrator", SparkRunnerKryoRegistrator.class.getName());
+ return new JavaSparkContext(conf);
}
}
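To illustrate the migration path the deprecation note above describes, here is a minimal sketch of a test fixture that owns the context lifecycle explicitly instead of relying on `beam.spark.test.reuseSparkContext`. The fixture class and method names are illustrative; the factory and options calls are the ones added in this change.

    import org.apache.beam.runners.spark.SparkContextOptions;
    import org.apache.beam.runners.spark.TestSparkRunner;
    import org.apache.beam.runners.spark.translation.SparkContextFactory;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.junit.AfterClass;
    import org.junit.BeforeClass;

    public class ProvidedContextLifecycleTest {
      private static JavaSparkContext jsc;

      @BeforeClass
      public static void provideContext() {
        // Replaces -Dbeam.spark.test.reuseSparkContext=true: create the context once
        // and register it with the factory so it survives option serialization.
        jsc = new JavaSparkContext("local", "shared-test-context");
        SparkContextFactory.setProvidedSparkContext(jsc);
      }

      @AfterClass
      public static void releaseContext() {
        // Unlike the reuse property, this does not leak the context across suites.
        SparkContextFactory.clearProvidedSparkContext();
        jsc.stop();
      }

      // Tests would build their options like this and run pipelines against them.
      static SparkContextOptions testOptions() {
        SparkContextOptions options = PipelineOptionsFactory.as(SparkContextOptions.class);
        options.setRunner(TestSparkRunner.class);
        options.setUsesProvidedSparkContext(true);
        return options;
      }
    }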
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/CacheTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/CacheTest.java
index 8209d4302717..861e13a4208f 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/CacheTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/CacheTest.java
@@ -25,11 +25,9 @@
import java.util.List;
import org.apache.beam.runners.spark.translation.Dataset;
import org.apache.beam.runners.spark.translation.EvaluationContext;
-import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Create.Values;
@@ -39,7 +37,7 @@
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.junit.ClassRule;
import org.junit.Test;
/** Tests of {@link Dataset#cache(String, Coder)}} scenarios. */
@@ -48,13 +46,15 @@
})
public class CacheTest {
+ @ClassRule public static SparkContextRule contextRule = new SparkContextRule();
+
/**
* Test checks how the cache candidates map is populated by the runner when evaluating the
* pipeline.
*/
@Test
public void cacheCandidatesUpdaterTest() {
- SparkPipelineOptions options = createOptions();
+ SparkPipelineOptions options = contextRule.createPipelineOptions();
Pipeline pipeline = Pipeline.create(options);
PCollection<String> pCollection = pipeline.apply(Create.of("foo", "bar"));
@@ -80,8 +80,8 @@ public void processElement(ProcessContext processContext) {
})
.withSideInputs(view));
- JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
- EvaluationContext ctxt = new EvaluationContext(jsc, pipeline, options);
+ EvaluationContext ctxt =
+ new EvaluationContext(contextRule.getSparkContext(), pipeline, options);
SparkRunner.CacheVisitor cacheVisitor =
new SparkRunner.CacheVisitor(new TransformTranslator.Translator(), ctxt);
pipeline.traverseTopologically(cacheVisitor);
@@ -91,15 +91,15 @@ public void processElement(ProcessContext processContext) {
@Test
public void shouldCacheTest() {
- SparkPipelineOptions options = createOptions();
+ SparkPipelineOptions options = contextRule.createPipelineOptions();
options.setCacheDisabled(true);
Pipeline pipeline = Pipeline.create(options);
Values<String> valuesTransform = Create.of("foo", "bar");
PCollection pCollection = mock(PCollection.class);
- JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
- EvaluationContext ctxt = new EvaluationContext(jsc, pipeline, options);
+ EvaluationContext ctxt =
+ new EvaluationContext(contextRule.getSparkContext(), pipeline, options);
ctxt.getCacheCandidates().put(pCollection, 2L);
assertFalse(ctxt.shouldCache(valuesTransform, pCollection));
@@ -110,11 +110,4 @@ public void shouldCacheTest() {
GroupByKey gbkTransform = GroupByKey.create();
assertFalse(ctxt.shouldCache(gbkTransform, pCollection));
}
-
- private SparkPipelineOptions createOptions() {
- SparkPipelineOptions options =
- PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
- options.setRunner(TestSparkRunner.class);
- return options;
- }
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java
index 7bcff9875db6..a4dc6afd9c45 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/GlobalWatermarkHolderTest.java
@@ -20,13 +20,12 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.IsEqual.equalTo;
-import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.RegexMatcher;
import org.joda.time.Duration;
import org.joda.time.Instant;
+import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -34,27 +33,21 @@
/** A test suite for the propagation of watermarks in the Spark runner. */
public class GlobalWatermarkHolderTest {
+ // Watermark holder requires valid SparkEnv
+ @ClassRule public static SparkContextRule contextRule = new SparkContextRule();
+
@Rule public ClearWatermarksRule clearWatermarksRule = new ClearWatermarksRule();
@Rule public ExpectedException thrown = ExpectedException.none();
- @Rule public ReuseSparkContextRule reuseContext = ReuseSparkContextRule.yes();
-
- // only needed in-order to get context from the SparkContextFactory.
- private static final SparkPipelineOptions options =
- PipelineOptionsFactory.create().as(SparkPipelineOptions.class);
-
private static final String INSTANT_PATTERN =
"[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{3}Z";
@Test
public void testLowHighWatermarksAdvance() {
-
Instant instant = new Instant(0);
// low == high.
- SparkContextFactory.getSparkContext(options);
-
GlobalWatermarkHolder.add(
1,
new SparkWatermarks(
@@ -98,7 +91,7 @@ public void testLowHighWatermarksAdvance() {
@Test
public void testSynchronizedTimeMonotonic() {
Instant instant = new Instant(0);
- SparkContextFactory.getSparkContext(options);
+
GlobalWatermarkHolder.add(
1,
new SparkWatermarks(
@@ -119,7 +112,7 @@ public void testSynchronizedTimeMonotonic() {
@Test
public void testMultiSource() {
Instant instant = new Instant(0);
- SparkContextFactory.getSparkContext(options);
+
GlobalWatermarkHolder.add(
1,
new SparkWatermarks(
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/ProvidedSparkContextTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/ProvidedSparkContextTest.java
index 4a57ade09cb5..0ef6bb0e078c 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/ProvidedSparkContextTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/ProvidedSparkContextTest.java
@@ -18,13 +18,12 @@
package org.apache.beam.runners.spark;
import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.fail;
+import static org.junit.Assert.assertThrows;
import org.apache.beam.runners.spark.examples.WordCount;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
@@ -32,10 +31,18 @@
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.junit.ClassRule;
+import org.junit.FixMethodOrder;
import org.junit.Test;
+import org.junit.runners.MethodSorters;
-/** Provided Spark Context tests. */
+/**
+ * Provided Spark Context tests.
+ *
+ * Note: These tests are run sequentially ordered by their name to reuse the Spark context and
+ * speed up testing.
+ */
+@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class ProvidedSparkContextTest {
private static final String[] WORDS_ARRAY = {
"hi there", "hi", "hi sue bob",
@@ -47,72 +54,50 @@ public class ProvidedSparkContextTest {
private static final String PROVIDED_CONTEXT_EXCEPTION =
"The provided Spark context was not created or was stopped";
+ @ClassRule
+ public static SparkContextOptionsRule contextOptionsRule = new SparkContextOptionsRule();
+
/** Provide a context and call pipeline run. */
@Test
- public void testWithProvidedContext() throws Exception {
- JavaSparkContext jsc = new JavaSparkContext("local[*]", "Existing_Context");
- testWithValidProvidedContext(jsc);
+ public void testAWithProvidedContext() throws Exception {
+ Pipeline p = createPipeline();
+ PipelineResult result = p.run(); // Run test from pipeline
+ result.waitUntilFinish();
+ TestPipeline.verifyPAssertsSucceeded(p, result);
// A provided context must not be stopped after execution
- assertFalse(jsc.sc().isStopped());
- jsc.stop();
+ assertFalse(contextOptionsRule.getSparkContext().sc().isStopped());
}
- /** Provide a context and call pipeline run. */
+ /** A SparkRunner with a stopped provided Spark context cannot run pipelines. */
@Test
- public void testWithNullContext() throws Exception {
- testWithInvalidContext(null);
+ public void testBWithStoppedProvidedContext() {
+ // Stop the provided Spark context
+ contextOptionsRule.getSparkContext().sc().stop();
+ assertThrows(
+ PROVIDED_CONTEXT_EXCEPTION,
+ RuntimeException.class,
+ () -> createPipeline().run().waitUntilFinish());
}
- /** A SparkRunner with a stopped provided Spark context cannot run pipelines. */
+ /** Provide a context and call pipeline run. */
@Test
- public void testWithStoppedProvidedContext() throws Exception {
- JavaSparkContext jsc = new JavaSparkContext("local[*]", "Existing_Context");
- // Stop the provided Spark context directly
- jsc.stop();
- testWithInvalidContext(jsc);
+ public void testCWithNullContext() {
+ contextOptionsRule.getOptions().setProvidedSparkContext(null);
+ assertThrows(
+ PROVIDED_CONTEXT_EXCEPTION,
+ RuntimeException.class,
+ () -> createPipeline().run().waitUntilFinish());
}
- private void testWithValidProvidedContext(JavaSparkContext jsc) throws Exception {
- SparkContextOptions options = getSparkContextOptions(jsc);
-
- Pipeline p = Pipeline.create(options);
+ private Pipeline createPipeline() {
+ Pipeline p = Pipeline.create(contextOptionsRule.getOptions());
PCollection<String> inputWords = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of()));
PCollection<String> output =
inputWords
.apply(new WordCount.CountWords())
.apply(MapElements.via(new WordCount.FormatAsTextFn()));
-
- PAssert.that(output).containsInAnyOrder(EXPECTED_COUNT_SET);
-
// Run test from pipeline
- PipelineResult result = p.run();
-
- TestPipeline.verifyPAssertsSucceeded(p, result);
- }
-
- private void testWithInvalidContext(JavaSparkContext jsc) {
- SparkContextOptions options = getSparkContextOptions(jsc);
-
- Pipeline p = Pipeline.create(options);
- PCollection<String> inputWords = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of()));
- inputWords
- .apply(new WordCount.CountWords())
- .apply(MapElements.via(new WordCount.FormatAsTextFn()));
-
- try {
- p.run().waitUntilFinish();
- fail("Should throw an exception when The provided Spark context is null or stopped");
- } catch (RuntimeException e) {
- assert e.getMessage().contains(PROVIDED_CONTEXT_EXCEPTION);
- }
- }
-
- private static SparkContextOptions getSparkContextOptions(JavaSparkContext jsc) {
- final SparkContextOptions options = PipelineOptionsFactory.as(SparkContextOptions.class);
- options.setRunner(TestSparkRunner.class);
- options.setUsesProvidedSparkContext(true);
- options.setProvidedSparkContext(jsc);
- options.setEnableSparkMetricSinks(false);
- return options;
+ PAssert.that(output).containsInAnyOrder(EXPECTED_COUNT_SET);
+ return p;
}
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/ReuseSparkContextRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkContextOptionsRule.java
similarity index 58%
rename from runners/spark/src/test/java/org/apache/beam/runners/spark/ReuseSparkContextRule.java
rename to runners/spark/src/test/java/org/apache/beam/runners/spark/SparkContextOptionsRule.java
index 54b77448f78a..2f424cd7ca40 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/ReuseSparkContextRule.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkContextOptionsRule.java
@@ -17,28 +17,27 @@
*/
package org.apache.beam.runners.spark;
-import org.apache.beam.runners.spark.translation.SparkContextFactory;
-import org.junit.rules.ExternalResource;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.values.KV;
-/** Explicitly set {@link org.apache.spark.SparkContext} to be reused (or not) in tests. */
-public class ReuseSparkContextRule extends ExternalResource {
+public class SparkContextOptionsRule extends SparkContextRule {
- private final boolean reuse;
+ private @Nullable SparkContextOptions contextOptions = null;
- private ReuseSparkContextRule(boolean reuse) {
- this.reuse = reuse;
- }
-
- public static ReuseSparkContextRule no() {
- return new ReuseSparkContextRule(false);
- }
-
- public static ReuseSparkContextRule yes() {
- return new ReuseSparkContextRule(true);
+ public SparkContextOptionsRule(KV<String, String>... sparkConfig) {
+ super(sparkConfig);
}
@Override
protected void before() throws Throwable {
- System.setProperty(SparkContextFactory.TEST_REUSE_SPARK_CONTEXT, Boolean.toString(reuse));
+ super.before();
+ contextOptions = createPipelineOptions();
+ }
+
+ public SparkContextOptions getOptions() {
+ if (contextOptions == null) {
+ throw new IllegalStateException("SparkContextOptions not available");
+ }
+ return contextOptions;
}
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkContextRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkContextRule.java
new file mode 100644
index 000000000000..caa7d8f6814b
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkContextRule.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark;
+
+import static java.util.stream.Collectors.toMap;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.spark.translation.SparkContextFactory;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.values.KV;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.junit.rules.ExternalResource;
+import org.junit.runner.Description;
+import org.junit.runners.model.Statement;
+
+public class SparkContextRule extends ExternalResource implements Serializable {
+ private transient SparkConf sparkConf;
+ private transient @Nullable JavaSparkContext sparkContext = null;
+
+ public SparkContextRule(String sparkMaster, Map<String, String> sparkConfig) {
+ sparkConf = new SparkConf();
+ sparkConfig.forEach(sparkConf::set);
+ sparkConf.setMaster(sparkMaster);
+ }
+
+ public SparkContextRule(KV<String, String>... sparkConfig) {
+ this("local", sparkConfig);
+ }
+
+ public SparkContextRule(String sparkMaster, KV<String, String>... sparkConfig) {
+ this(sparkMaster, Arrays.stream(sparkConfig).collect(toMap(KV::getKey, KV::getValue)));
+ }
+
+ public JavaSparkContext getSparkContext() {
+ if (sparkContext == null) {
+ throw new IllegalStateException("SparkContext not available");
+ }
+ return sparkContext;
+ }
+
+ public SparkContextOptions createPipelineOptions() {
+ return configure(TestPipeline.testingPipelineOptions());
+ }
+
+ public SparkContextOptions configure(PipelineOptions opts) {
+ SparkContextOptions ctxOpts = opts.as(SparkContextOptions.class);
+ ctxOpts.setUsesProvidedSparkContext(true);
+ ctxOpts.setProvidedSparkContext(getSparkContext());
+ return ctxOpts;
+ }
+
+ @Override
+ public Statement apply(Statement base, Description description) {
+ sparkConf.setAppName(description.getDisplayName());
+ return super.apply(base, description);
+ }
+
+ @Override
+ protected void before() throws Throwable {
+ sparkContext = new JavaSparkContext(sparkConf);
+ SparkContextFactory.setProvidedSparkContext(sparkContext);
+ }
+
+ @Override
+ protected void after() {
+ SparkContextFactory.clearProvidedSparkContext();
+ getSparkContext().stop();
+ sparkContext = null;
+ }
+}
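For orientation, a minimal sketch of how this new rule is intended to be used, mirroring the tests updated below; the test class and method names are illustrative.

    import org.apache.beam.runners.spark.SparkContextOptions;
    import org.apache.beam.runners.spark.SparkContextRule;
    import org.apache.beam.runners.spark.TestSparkRunner;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.transforms.Create;
    import org.junit.ClassRule;
    import org.junit.Test;

    public class MySparkRunnerTest {
      // One SparkContext for the whole test class; the rule registers it with
      // SparkContextFactory in before() and stops/clears it in after().
      @ClassRule public static SparkContextRule contextRule = new SparkContextRule();

      @Test
      public void runsOnTheProvidedContext() {
        SparkContextOptions options = contextRule.createPipelineOptions();
        options.setRunner(TestSparkRunner.class);
        Pipeline pipeline = Pipeline.create(options);
        pipeline.apply(Create.of("a", "b", "c"));
        pipeline.run().waitUntilFinish();
      }
    }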
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
index b48f553d8fc5..61e111234331 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
@@ -20,190 +20,137 @@
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
-import static org.junit.Assert.fail;
+import static org.joda.time.Duration.millis;
+import static org.junit.Assert.assertThrows;
import java.io.Serializable;
+import javax.annotation.Nullable;
import org.apache.beam.runners.spark.io.CreateStream;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
-import org.joda.time.Duration;
-import org.junit.Rule;
+import org.junit.ClassRule;
import org.junit.Test;
-import org.junit.rules.TestName;
/** This suite tests that various scenarios result in proper states of the pipeline. */
public class SparkPipelineStateTest implements Serializable {
- private static class MyCustomException extends RuntimeException {
+ @ClassRule public static SparkContextRule contextRule = new SparkContextRule();
- MyCustomException(final String message) {
+ private static class CustomException extends RuntimeException {
+ CustomException(final String message) {
super(message);
}
}
- private final transient SparkPipelineOptions options =
- PipelineOptionsFactory.as(SparkPipelineOptions.class);
-
- @Rule public transient TestName testName = new TestName();
-
- private static final String FAILED_THE_BATCH_INTENTIONALLY = "Failed the batch intentionally";
-
- private ParDo.SingleOutput<String, String> printParDo(final String prefix) {
- return ParDo.of(
- new DoFn<String, String>() {
-
- @ProcessElement
- public void processElement(final ProcessContext c) {
- System.out.println(prefix + " " + c.element());
- }
- });
- }
-
- private PTransform<PBegin, PCollection<String>> getValues(final SparkPipelineOptions options) {
- final boolean doNotSyncWithWatermark = false;
- return options.isStreaming()
- ? CreateStream.of(StringUtf8Coder.of(), Duration.millis(1), doNotSyncWithWatermark)
- .nextBatch("one", "two")
- : Create.of("one", "two");
+ private static class FailAlways extends SimpleFunction<String, String> {
+ @Override
+ public String apply(final String input) {
+ throw new CustomException(FAILED_THE_BATCH_INTENTIONALLY);
+ }
}
- private SparkPipelineOptions getStreamingOptions() {
- options.setRunner(SparkRunner.class);
- options.setStreaming(true);
- return options;
- }
+ private static final String FAILED_THE_BATCH_INTENTIONALLY = "Failed the batch intentionally";
- private SparkPipelineOptions getBatchOptions() {
+ private Pipeline createPipeline(
+ boolean isStreaming, @Nullable SimpleFunction<String, String> mapFun) {
+ SparkContextOptions options = contextRule.createPipelineOptions();
options.setRunner(SparkRunner.class);
- options.setStreaming(false); // explicit because options is reused throughout the test.
- return options;
- }
+ options.setStreaming(isStreaming);
- private Pipeline getPipeline(final SparkPipelineOptions options) {
-
- final Pipeline pipeline = Pipeline.create(options);
- final String name = testName.getMethodName() + "(isStreaming=" + options.isStreaming() + ")";
-
- pipeline.apply(getValues(options)).setCoder(StringUtf8Coder.of()).apply(printParDo(name));
+ Pipeline pipeline = Pipeline.create(options);
+ PTransform<PBegin, PCollection<String>> values =
+ isStreaming
+ ? CreateStream.of(StringUtf8Coder.of(), millis(1), false).nextBatch("one", "two")
+ : Create.of("one", "two");
+ PCollection<String> collection = pipeline.apply(values).setCoder(StringUtf8Coder.of());
+ if (mapFun != null) {
+ collection.apply(MapElements.via(mapFun));
+ }
return pipeline;
}
- private void testFailedPipeline(final SparkPipelineOptions options) throws Exception {
-
- SparkPipelineResult result = null;
-
- try {
- final Pipeline pipeline = Pipeline.create(options);
- pipeline
- .apply(getValues(options))
- .setCoder(StringUtf8Coder.of())
- .apply(
- MapElements.via(
- new SimpleFunction<String, String>() {
-
- @Override
- public String apply(final String input) {
- throw new MyCustomException(FAILED_THE_BATCH_INTENTIONALLY);
- }
- }));
-
- result = (SparkPipelineResult) pipeline.run();
- result.waitUntilFinish();
- } catch (final Exception e) {
- assertThat(e, instanceOf(Pipeline.PipelineExecutionException.class));
- assertThat(e.getCause(), instanceOf(MyCustomException.class));
- assertThat(e.getCause().getMessage(), is(FAILED_THE_BATCH_INTENTIONALLY));
- assertThat(result.getState(), is(PipelineResult.State.FAILED));
- result.cancel();
- return;
- }
+ private void testFailedPipeline(boolean isStreaming) throws Exception {
+ Pipeline pipeline = createPipeline(isStreaming, new FailAlways());
+ SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
- fail("An injected failure did not affect the pipeline as expected.");
+ PipelineExecutionException e =
+ assertThrows(PipelineExecutionException.class, () -> result.waitUntilFinish());
+ assertThat(e.getCause(), instanceOf(CustomException.class));
+ assertThat(e.getCause().getMessage(), is(FAILED_THE_BATCH_INTENTIONALLY));
+ assertThat(result.getState(), is(PipelineResult.State.FAILED));
+ result.cancel();
}
- private void testTimeoutPipeline(final SparkPipelineOptions options) throws Exception {
-
- final Pipeline pipeline = getPipeline(options);
-
- final SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
-
- result.waitUntilFinish(Duration.millis(1));
+ private void testWaitUntilFinishedTimeout(boolean isStreaming) throws Exception {
+ Pipeline pipeline = createPipeline(isStreaming, null);
+ SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
+ result.waitUntilFinish(millis(1));
+ // Wait timed out, pipeline is still running
assertThat(result.getState(), is(PipelineResult.State.RUNNING));
-
result.cancel();
}
- private void testCanceledPipeline(final SparkPipelineOptions options) throws Exception {
-
- final Pipeline pipeline = getPipeline(options);
-
- final SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
-
+ private void testCanceledPipeline(boolean isStreaming) throws Exception {
+ Pipeline pipeline = createPipeline(isStreaming, null);
+ SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
result.cancel();
-
assertThat(result.getState(), is(PipelineResult.State.CANCELLED));
}
- private void testRunningPipeline(final SparkPipelineOptions options) throws Exception {
-
- final Pipeline pipeline = getPipeline(options);
-
- final SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
-
+ private void testRunningPipeline(boolean isStreaming) throws Exception {
+ Pipeline pipeline = createPipeline(isStreaming, null);
+ SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
assertThat(result.getState(), is(PipelineResult.State.RUNNING));
-
result.cancel();
}
@Test
public void testStreamingPipelineRunningState() throws Exception {
- testRunningPipeline(getStreamingOptions());
+ testRunningPipeline(true);
}
@Test
public void testBatchPipelineRunningState() throws Exception {
- testRunningPipeline(getBatchOptions());
+ testRunningPipeline(false);
}
@Test
public void testStreamingPipelineCanceledState() throws Exception {
- testCanceledPipeline(getStreamingOptions());
+ testCanceledPipeline(true);
}
@Test
public void testBatchPipelineCanceledState() throws Exception {
- testCanceledPipeline(getBatchOptions());
+ testCanceledPipeline(false);
}
@Test
public void testStreamingPipelineFailedState() throws Exception {
- testFailedPipeline(getStreamingOptions());
+ testFailedPipeline(true);
}
@Test
public void testBatchPipelineFailedState() throws Exception {
- testFailedPipeline(getBatchOptions());
+ testFailedPipeline(false);
}
@Test
- public void testStreamingPipelineTimeoutState() throws Exception {
- testTimeoutPipeline(getStreamingOptions());
+ public void testStreamingPipelineWaitTimeout() throws Exception {
+ testWaitUntilFinishedTimeout(true);
}
@Test
- public void testBatchPipelineTimeoutState() throws Exception {
- testTimeoutPipeline(getBatchOptions());
+ public void testBatchPipelineWaitTimeout() throws Exception {
+ testWaitUntilFinishedTimeout(false);
}
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java
index c9bb83dd0c34..91ef5a426401 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java
@@ -49,6 +49,7 @@
import org.apache.kafka.common.serialization.StringSerializer;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
+import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -57,11 +58,12 @@
@RunWith(JUnit4.class)
public class SparkRunnerDebuggerTest {
+ @ClassRule public static SparkContextRule contextRule = new SparkContextRule("local[1]");
+
@Test
public void debugBatchPipeline() {
- PipelineOptions options = PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
+ PipelineOptions options = contextRule.configure(PipelineOptionsFactory.create());
options.setRunner(SparkRunnerDebugger.class);
-
Pipeline pipeline = Pipeline.create(options);
PCollection<String> lines =
@@ -105,11 +107,9 @@ public void debugBatchPipeline() {
@Test
public void debugStreamingPipeline() {
- TestSparkPipelineOptions options =
- PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
- options.setForceStreaming(true);
+ PipelineOptions options = contextRule.configure(PipelineOptionsFactory.create());
options.setRunner(SparkRunnerDebugger.class);
-
+ options.as(TestSparkPipelineOptions.class).setForceStreaming(true);
Pipeline pipeline = Pipeline.create(options);
KafkaIO.Read<String, String> read =
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/SparkMetricsSinkTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/SparkMetricsSinkTest.java
index 7439ebfeb726..f21168336d02 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/SparkMetricsSinkTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/SparkMetricsSinkTest.java
@@ -21,7 +21,7 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
-import org.apache.beam.runners.spark.ReuseSparkContextRule;
+import org.apache.beam.runners.spark.SparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.StreamingTest;
import org.apache.beam.runners.spark.examples.WordCount;
@@ -39,6 +39,7 @@
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.joda.time.Duration;
import org.joda.time.Instant;
+import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
@@ -49,9 +50,13 @@
* streaming modes.
*/
public class SparkMetricsSinkTest {
+ @ClassRule public static SparkContextRule contextRule = new SparkContextRule();
+
@Rule public ExternalResource inMemoryMetricsSink = new InMemoryMetricsSinkRule();
- @Rule public final TestPipeline pipeline = TestPipeline.create();
- @Rule public final transient ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();
+
+ @Rule
+ public final TestPipeline pipeline =
+ TestPipeline.fromOptions(contextRule.createPipelineOptions());
private static final ImmutableList<String> WORDS =
ImmutableList.of("hi there", "hi", "hi sue bob", "hi sue", "", "bob hi");
@@ -68,7 +73,7 @@ public void testInBatchMode() throws Exception {
.apply(new WordCount.CountWords())
.apply(MapElements.via(new WordCount.FormatAsTextFn()));
PAssert.that(output).containsInAnyOrder(EXPECTED_COUNTS);
- pipeline.run();
+ pipeline.run().waitUntilFinish();
assertThat(InMemoryMetrics.valueOf("emptyLines"), is(1d));
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistratorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistratorTest.java
index 390b127871a4..fcc7fee27063 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistratorTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistratorTest.java
@@ -17,96 +17,84 @@
*/
package org.apache.beam.runners.spark.coders;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.Registration;
-import org.apache.beam.runners.spark.SparkContextOptions;
-import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.TestSparkPipelineOptions;
-import org.apache.beam.runners.spark.TestSparkRunner;
+import org.apache.beam.runners.spark.SparkContextRule;
+import org.apache.beam.runners.spark.coders.SparkRunnerKryoRegistratorTest.Others.TestKryoRegistrator;
import org.apache.beam.runners.spark.io.MicrobatchSource;
import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.beam.sdk.values.KV;
+import org.junit.ClassRule;
import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.runner.RunWith;
-/** Testing of beam registrar. */
+/**
+ * Testing of beam registrar. Note: There can only be one Spark context at a time. For that reason
+ * tests requiring a different context have to be forked using separate test classes.
+ */
@SuppressWarnings({
"rawtypes" // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
})
+@RunWith(Enclosed.class)
public class SparkRunnerKryoRegistratorTest {
- @Test
- public void testKryoRegistration() {
- SparkConf conf = new SparkConf();
- conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
- conf.set("spark.kryo.registrator", WrapperKryoRegistrator.class.getName());
- runSimplePipelineWithSparkContext(conf);
- assertTrue(
- "WrapperKryoRegistrator wasn't initiated, probably KryoSerializer is not set",
- WrapperKryoRegistrator.wasInitiated);
- }
-
- @Test
- public void testDefaultSerializerNotCallingKryo() {
- SparkConf conf = new SparkConf();
- conf.set("spark.kryo.registrator", KryoRegistratorIsNotCalled.class.getName());
- runSimplePipelineWithSparkContext(conf);
- }
+ public static class WithKryoSerializer {
- private void runSimplePipelineWithSparkContext(SparkConf conf) {
- SparkPipelineOptions options =
- PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
- options.setRunner(TestSparkRunner.class);
+ @ClassRule
+ public static SparkContextRule contextRule =
+ new SparkContextRule(
+ KV.of("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
+ KV.of("spark.kryo.registrator", TestKryoRegistrator.class.getName()));
- conf.set("spark.master", "local");
- conf.setAppName("test");
-
- JavaSparkContext javaSparkContext = new JavaSparkContext(conf);
- options.setUsesProvidedSparkContext(true);
- options.as(SparkContextOptions.class).setProvidedSparkContext(javaSparkContext);
- Pipeline p = Pipeline.create(options);
- p.apply(Create.of("a")); // some operation to trigger pipeline construction
- p.run().waitUntilFinish();
- javaSparkContext.stop();
+ @Test
+ public void testKryoRegistration() {
+ TestKryoRegistrator.wasInitiated = false;
+ runSimplePipelineWithSparkContextOptions(contextRule);
+ assertTrue(TestKryoRegistrator.wasInitiated);
+ }
}
- /**
- * A {@link SparkRunnerKryoRegistrator} that fails if called. Use only for test purposes. Needs to
- * be public for serialization.
- */
- public static class KryoRegistratorIsNotCalled extends SparkRunnerKryoRegistrator {
+ public static class WithoutKryoSerializer {
+ @ClassRule
+ public static SparkContextRule contextRule =
+ new SparkContextRule(KV.of("spark.kryo.registrator", TestKryoRegistrator.class.getName()));
- @Override
- public void registerClasses(Kryo kryo) {
- fail(
- "Default spark.serializer is JavaSerializer"
- + " so spark.kryo.registrator shouldn't be called");
+ @Test
+ public void testDefaultSerializerNotCallingKryo() {
+ TestKryoRegistrator.wasInitiated = false;
+ runSimplePipelineWithSparkContextOptions(contextRule);
+ assertFalse(TestKryoRegistrator.wasInitiated);
}
}
- /**
- * A {@link SparkRunnerKryoRegistrator} that registers an internal class to validate
- * KryoSerialization resolution. Use only for test purposes. Needs to be public for serialization.
- */
- public static class WrapperKryoRegistrator extends SparkRunnerKryoRegistrator {
+ // Hide TestKryoRegistrator from the Enclosed JUnit runner
+ interface Others {
+ class TestKryoRegistrator extends SparkRunnerKryoRegistrator {
- static boolean wasInitiated = false;
+ static boolean wasInitiated = false;
- public WrapperKryoRegistrator() {
- wasInitiated = true;
- }
+ public TestKryoRegistrator() {
+ wasInitiated = true;
+ }
- @Override
- public void registerClasses(Kryo kryo) {
- super.registerClasses(kryo);
- Registration registration = kryo.getRegistration(MicrobatchSource.class);
- com.esotericsoftware.kryo.Serializer kryoSerializer = registration.getSerializer();
- assertTrue(kryoSerializer instanceof StatelessJavaSerializer);
+ @Override
+ public void registerClasses(Kryo kryo) {
+ super.registerClasses(kryo);
+ // verify serializer for MicrobatchSource
+ Registration registration = kryo.getRegistration(MicrobatchSource.class);
+ assertTrue(registration.getSerializer() instanceof StatelessJavaSerializer);
+ }
}
}
+
+ private static void runSimplePipelineWithSparkContextOptions(SparkContextRule context) {
+ Pipeline p = Pipeline.create(context.createPipelineOptions());
+ p.apply(Create.of("a")); // some operation to trigger pipeline construction
+ p.run().waitUntilFinish();
+ }
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/metrics/SparkMetricsPusherTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/metrics/SparkMetricsPusherTest.java
index bc4f7507ca05..aa7ab616ecd8 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/metrics/SparkMetricsPusherTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/metrics/SparkMetricsPusherTest.java
@@ -21,7 +21,6 @@
import static org.hamcrest.Matchers.is;
import org.apache.beam.runners.core.metrics.TestMetricsSink;
-import org.apache.beam.runners.spark.ReuseSparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.StreamingTest;
import org.apache.beam.runners.spark.io.CreateStream;
@@ -52,8 +51,6 @@ public class SparkMetricsPusherTest {
private static final Logger LOG = LoggerFactory.getLogger(SparkMetricsPusherTest.class);
private static final String COUNTER_NAME = "counter";
- @Rule public final transient ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();
-
@Rule public final TestPipeline pipeline = TestPipeline.create();
private Duration batchDuration() {
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
new file mode 100644
index 000000000000..f68df83ac07d
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming;
+
+import static java.util.stream.Collectors.toMap;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.values.KV;
+import org.apache.spark.sql.SparkSession;
+import org.junit.rules.ExternalResource;
+import org.junit.runner.Description;
+import org.junit.runners.model.Statement;
+
+public class SparkSessionRule extends ExternalResource implements Serializable {
+ private transient SparkSession.Builder builder;
+ private transient @Nullable SparkSession session = null;
+
+ public SparkSessionRule(String sparkMaster, Map<String, String> sparkConfig) {
+ builder = SparkSession.builder();
+ sparkConfig.forEach(builder::config);
+ builder.master(sparkMaster);
+ }
+
+ public SparkSessionRule(KV<String, String>... sparkConfig) {
+ this("local", sparkConfig);
+ }
+
+ public SparkSessionRule(String sparkMaster, KV<String, String>... sparkConfig) {
+ this(sparkMaster, Arrays.stream(sparkConfig).collect(toMap(KV::getKey, KV::getValue)));
+ }
+
+ public SparkSession getSession() {
+ if (session == null) {
+ throw new IllegalStateException("SparkSession not available");
+ }
+ return session;
+ }
+
+ @Override
+ public Statement apply(Statement base, Description description) {
+ builder.appName(description.getDisplayName());
+ return super.apply(base, description);
+ }
+
+ @Override
+ protected void before() throws Throwable {
+ session = builder.getOrCreate();
+ }
+
+ @Override
+ protected void after() {
+ getSession().stop();
+ session = null;
+ }
+}
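Analogous to SparkContextRule, a minimal usage sketch for the structured-streaming side, in the style of the EncoderHelpersTest change below. The class and method names are illustrative, and Encoders.INT() stands in for any encoder, including the Beam coder-backed ones used by the runner.

    import static org.junit.Assert.assertEquals;

    import java.util.Arrays;
    import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.junit.ClassRule;
    import org.junit.Test;

    public class MySparkSqlTest {
      // One SparkSession per test class, created in before() and stopped in after().
      @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule();

      @Test
      public void sharesTheSession() {
        Dataset<Integer> dataset =
            sessionRule.getSession().createDataset(Arrays.asList(1, 2, 3), Encoders.INT());
        assertEquals(3, dataset.count());
      }
    }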
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
index 54db4fae1c24..3151a5fe956f 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
@@ -21,9 +21,10 @@
import java.util.Arrays;
import java.util.List;
+import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SparkSession;
+import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -32,16 +33,15 @@
@RunWith(JUnit4.class)
public class EncoderHelpersTest {
+ @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule();
+
@Test
public void beamCoderToSparkEncoderTest() {
- SparkSession sparkSession =
- SparkSession.builder()
- .appName("beamCoderToSparkEncoderTest")
- .master("local[4]")
- .getOrCreate();
List<Integer> data = Arrays.asList(1, 2, 3);
Dataset<Integer> dataset =
- sparkSession.createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
+ sessionRule
+ .getSession()
+ .createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
assertEquals(data, dataset.collectAsList());
}
}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
index 2a40b45136a9..8fde97456227 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
@@ -29,7 +29,6 @@
import java.io.IOException;
import java.io.Serializable;
import java.util.concurrent.atomic.AtomicInteger;
-import org.apache.beam.runners.spark.ReuseSparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.StreamingTest;
import org.apache.beam.runners.spark.io.CreateStream;
@@ -86,7 +85,6 @@
public class CreateStreamTest implements Serializable {
@Rule public final transient TestPipeline p = TestPipeline.create();
- @Rule public final transient ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();
@Rule public final transient ExpectedException thrown = ExpectedException.none();
@Test
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
index 6c107c474b66..e7f45d99e513 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
@@ -30,7 +30,6 @@
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
-import org.apache.beam.runners.spark.ReuseSparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineResult;
import org.apache.beam.runners.spark.TestSparkPipelineOptions;
import org.apache.beam.runners.spark.TestSparkRunner;
@@ -84,7 +83,6 @@
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
-import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.TemporaryFolder;
@@ -112,8 +110,6 @@ public class ResumeFromCheckpointStreamingTest implements Serializable {
private transient TemporaryFolder temporaryFolder;
- @Rule public final transient ReuseSparkContextRule noContextReuse = ReuseSparkContextRule.no();
-
@BeforeClass
public static void setup() throws IOException {
EMBEDDED_ZOOKEEPER.startup();
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SparkCoGroupByKeyStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SparkCoGroupByKeyStreamingTest.java
index 407b07ac0d6d..fc4e427e2f30 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SparkCoGroupByKeyStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SparkCoGroupByKeyStreamingTest.java
@@ -22,7 +22,6 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
-import org.apache.beam.runners.spark.ReuseSparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.StreamingTest;
import org.apache.beam.runners.spark.io.CreateStream;
@@ -53,8 +52,6 @@ public class SparkCoGroupByKeyStreamingTest {
private static final TupleTag<Integer> INPUT1_TAG = new TupleTag<>("input1");
private static final TupleTag<Integer> INPUT2_TAG = new TupleTag<>("input2");
- @Rule public final transient ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();
-
@Rule public final TestPipeline pipeline = TestPipeline.create();
private Duration batchDuration() {
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
index 5ede41aaedaf..79bc8a0a71a2 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
@@ -22,17 +22,15 @@
import static org.hamcrest.core.IsEqual.equalTo;
import java.util.List;
-import org.apache.beam.runners.spark.ReuseSparkContextRule;
+import org.apache.beam.runners.spark.SparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.runners.spark.io.CreateStream;
import org.apache.beam.runners.spark.translation.Dataset;
import org.apache.beam.runners.spark.translation.EvaluationContext;
-import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.VarIntCoder;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.DoFn;
@@ -46,7 +44,7 @@
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.joda.time.Duration;
import org.junit.Before;
-import org.junit.Rule;
+import org.junit.ClassRule;
import org.junit.Test;
/**
@@ -58,10 +56,7 @@
})
public class TrackStreamingSourcesTest {
- @Rule public ReuseSparkContextRule reuseContext = ReuseSparkContextRule.yes();
-
- private static final transient SparkPipelineOptions options =
- PipelineOptionsFactory.create().as(SparkPipelineOptions.class);
+ @ClassRule public static SparkContextRule sparkContext = new SparkContextRule();
@Before
public void before() {
@@ -70,8 +65,9 @@ public void before() {
@Test
public void testTrackSingle() {
+ SparkPipelineOptions options = sparkContext.createPipelineOptions();
options.setRunner(SparkRunner.class);
- JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
+ JavaSparkContext jsc = sparkContext.getSparkContext();
JavaStreamingContext jssc =
new JavaStreamingContext(
jsc, new org.apache.spark.streaming.Duration(options.getBatchIntervalMillis()));
@@ -90,8 +86,9 @@ public void testTrackSingle() {
@Test
public void testTrackFlattened() {
+ SparkPipelineOptions options = sparkContext.createPipelineOptions();
options.setRunner(SparkRunner.class);
- JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
+ JavaSparkContext jsc = sparkContext.getSparkContext();
JavaStreamingContext jssc =
new JavaStreamingContext(
jsc, new org.apache.spark.streaming.Duration(options.getBatchIntervalMillis()));
@@ -135,7 +132,7 @@ private StreamingSourceTracker(
Pipeline pipeline,
Class<? extends PTransform> transformClassToAssert,
Integer... expected) {
- this.ctxt = new EvaluationContext(jssc.sparkContext(), pipeline, options, jssc);
+ this.ctxt = new EvaluationContext(jssc.sparkContext(), pipeline, pipeline.getOptions(), jssc);
this.evaluator =
new SparkRunner.Evaluator(
new StreamingTransformTranslator.Translator(new TransformTranslator.Translator()),