14 changes: 3 additions & 11 deletions runners/spark/spark_runner.gradle
@@ -94,7 +94,6 @@ if (copySourceBase) {
}

test {
systemProperty "beam.spark.test.reuseSparkContext", "true"
systemProperty "spark.sql.shuffle.partitions", "4"
systemProperty "spark.ui.enabled", "false"
systemProperty "spark.ui.showConsoleProgress", "false"
@@ -113,17 +112,14 @@ test {
jvmArgs System.getProperty("beamSurefireArgline")
}

// Only one SparkContext may be running in a JVM (SPARK-2243)
forkEvery 1
maxParallelForks 4
useJUnit {
excludeCategories "org.apache.beam.runners.spark.StreamingTest"
excludeCategories "org.apache.beam.runners.spark.UsesCheckpointRecovery"
}
filter {
// BEAM-11653 MetricsSinkTest is failing with Spark 3
excludeTestsMatching 'org.apache.beam.runners.spark.aggregators.metrics.sink.SparkMetricsSinkTest'
}

// easily re-run all tests (to deal with flaky tests / SparkContext leaks)
if(project.hasProperty("rerun-tests")) { outputs.upToDateWhen {false} }
}

dependencies {
@@ -289,10 +285,6 @@ def validatesRunnerStreaming = tasks.register("validatesRunnerStreaming", Test)
useJUnit {
includeCategories 'org.apache.beam.runners.spark.StreamingTest'
}
filter {
// BEAM-11653 MetricsSinkTest is failing with Spark 3
excludeTestsMatching 'org.apache.beam.runners.spark.aggregators.metrics.sink.SparkMetricsSinkTest'
}
}

tasks.register("validatesStructuredStreamingRunnerBatch", Test) {
@@ -37,6 +37,12 @@
* which link to Spark dependencies, won't be scanned by {@link PipelineOptions} reflective
* instantiation. Note that {@link SparkContextOptions} is not registered with {@link
* SparkRunnerRegistrar}.
*
* <p>Note: It's recommended to use {@link
* org.apache.beam.runners.spark.translation.SparkContextFactory#setProvidedSparkContext(JavaSparkContext)}
* instead of {@link SparkContextOptions#setProvidedSparkContext(JavaSparkContext)} for testing.
* When using {@link org.apache.beam.sdk.testing.TestPipeline}, any {@link JavaSparkContext} provided
* via {@link SparkContextOptions} is dropped.
*/
public interface SparkContextOptions extends SparkPipelineOptions {

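Not part of the diff: the new Javadoc note above recommends registering a test context through SparkContextFactory instead of SparkContextOptions#setProvidedSparkContext, because the latter is dropped when TestPipeline serializes the options. A minimal sketch of that pattern, assuming a JUnit 4 test and an arbitrary local master; the class name and test body are illustrative only:

import org.apache.beam.runners.spark.SparkContextOptions;
import org.apache.beam.runners.spark.TestSparkRunner;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;

public class ProvidedSparkContextExampleTest {

  // Externally managed context, shared by all tests in this class.
  private static JavaSparkContext jsc;

  @BeforeClass
  public static void registerSparkContext() {
    jsc = new JavaSparkContext("local[2]", "ProvidedSparkContextExampleTest");
    // Register with the factory; a context set on SparkContextOptions would be lost on serialization.
    SparkContextFactory.setProvidedSparkContext(jsc);
  }

  @AfterClass
  public static void releaseSparkContext() {
    SparkContextFactory.clearProvidedSparkContext();
    jsc.stop();
  }

  private static SparkContextOptions createOptions() {
    SparkContextOptions options = PipelineOptionsFactory.as(SparkContextOptions.class);
    options.setRunner(TestSparkRunner.class);
    // Tell the runner to resolve the provided context via SparkContextFactory.
    options.setUsesProvidedSparkContext(true);
    return options;
  }

  @Rule public final transient TestPipeline pipeline = TestPipeline.fromOptions(createOptions());

  @Test
  public void runsOnProvidedContext() {
    pipeline.apply(Create.of("foo", "bar"));
    pipeline.run().waitUntilFinish();
  }
}

Registering the context in @BeforeClass and clearing it in @AfterClass keeps its lifecycle with the test; with usesProvidedSparkContext set, getSparkContext falls back to the registered context when the one on the options has been dropped.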
@@ -21,6 +21,7 @@
import java.util.concurrent.TimeoutException;
import org.apache.beam.runners.core.construction.SplittableParDo;
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.runners.spark.translation.streaming.StreamingTransformTranslator;
@@ -86,7 +87,8 @@ public SparkPipelineResult run(Pipeline pipeline) {
SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReadsIfNecessary(pipeline);
}

JavaSparkContext jsc = new JavaSparkContext("local[1]", "Debug_Pipeline");
JavaSparkContext jsc =
SparkContextFactory.getSparkContext(pipeline.getOptions().as(SparkPipelineOptions.class));
JavaStreamingContext jssc =
new JavaStreamingContext(jsc, new org.apache.spark.streaming.Duration(1000));

@@ -107,7 +109,7 @@ public SparkPipelineResult run(Pipeline pipeline) {

pipeline.traverseTopologically(visitor);

jsc.stop();
SparkContextFactory.stopSparkContext(jsc);

String debugString = visitor.getDebugString();
LOG.info("Translated Native Spark pipeline:\n" + debugString);
@@ -17,6 +17,9 @@
*/
package org.apache.beam.runners.spark.translation;

import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;

import javax.annotation.Nullable;
import org.apache.beam.runners.spark.SparkContextOptions;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.SparkRunnerKryoRegistrator;
@@ -25,80 +28,121 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** The Spark context factory. */
@SuppressWarnings({
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public final class SparkContextFactory {
private static final Logger LOG = LoggerFactory.getLogger(SparkContextFactory.class);

/**
* If the property {@code beam.spark.test.reuseSparkContext} is set to {@code true}, then the Spark
* context will be reused for Beam pipelines. This property should only be enabled for tests.
*
* @deprecated This will leak your SparkContext; any attempt to create a new SparkContext later
* will fail. Please use {@link #setProvidedSparkContext(JavaSparkContext)} / {@link
* #clearProvidedSparkContext()} instead to properly control the lifecycle of your context.
* Alternatively you may also provide a SparkContext using {@link
* SparkContextOptions#setUsesProvidedSparkContext(boolean)} together with {@link
* SparkContextOptions#setProvidedSparkContext(JavaSparkContext)} and close that one
* appropriately. Tests of this module should use {@code SparkContextRule}.
*/
@Deprecated
public static final String TEST_REUSE_SPARK_CONTEXT = "beam.spark.test.reuseSparkContext";

// Spark allows only one context per JVM, so this can be static.
private static JavaSparkContext sparkContext;
private static String sparkMaster;
private static boolean usesProvidedSparkContext;
private static @Nullable JavaSparkContext sparkContext;

// Remember spark master if TEST_REUSE_SPARK_CONTEXT is enabled.
private static @Nullable String reusableSparkMaster;

// SparkContext is provided by the user instead of simply reused using TEST_REUSE_SPARK_CONTEXT
private static boolean hasProvidedSparkContext;

private SparkContextFactory() {}

/**
* Set an externally managed {@link JavaSparkContext} that will be used if {@link
* SparkContextOptions#getUsesProvidedSparkContext()} is set to {@code true}.
*
* <p>A Spark context can also be provided using {@link
* SparkContextOptions#setProvidedSparkContext(JavaSparkContext)}. However, it will be dropped
* during serialization, potentially leading to confusing behavior. This is particularly the case
* when used in tests with {@link org.apache.beam.sdk.testing.TestPipeline}.
*/
public static synchronized void setProvidedSparkContext(JavaSparkContext providedSparkContext) {
sparkContext = checkNotNull(providedSparkContext);
hasProvidedSparkContext = true;
reusableSparkMaster = null;
}

public static synchronized void clearProvidedSparkContext() {
hasProvidedSparkContext = false;
sparkContext = null;
}

public static synchronized JavaSparkContext getSparkContext(SparkPipelineOptions options) {
SparkContextOptions contextOptions = options.as(SparkContextOptions.class);
usesProvidedSparkContext = contextOptions.getUsesProvidedSparkContext();
// reuse should be ignored if the context is provided.
if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !usesProvidedSparkContext) {

// if the context is null or stopped for some reason, re-create it.
if (sparkContext == null || sparkContext.sc().isStopped()) {
sparkContext = createSparkContext(contextOptions);
sparkMaster = options.getSparkMaster();
} else if (!options.getSparkMaster().equals(sparkMaster)) {
throw new IllegalArgumentException(
if (contextOptions.getUsesProvidedSparkContext()) {
JavaSparkContext jsc = contextOptions.getProvidedSparkContext();
if (jsc != null) {
setProvidedSparkContext(jsc);
} else if (hasProvidedSparkContext) {
jsc = sparkContext;
}
if (jsc == null) {
throw new IllegalStateException(
"No Spark context was provided. Use SparkContextFactor.setProvidedSparkContext to do so.");
} else if (jsc.sc().isStopped()) {
LOG.error("The provided Spark context " + jsc + " was already stopped.");
throw new IllegalStateException("The provided Spark context was already stopped");
}
LOG.info("Using a provided Spark Context");
return jsc;
} else if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT)) {
// This is highly discouraged as it leaks the SparkContext without any way to close it.
// Attempting to create any new SparkContext later will fail.
// If the context is null or stopped for some reason, re-create it.
@Nullable JavaSparkContext jsc = sparkContext;
if (jsc == null || jsc.sc().isStopped()) {
sparkContext = jsc = createSparkContext(contextOptions);
reusableSparkMaster = options.getSparkMaster();
hasProvidedSparkContext = false;
} else if (hasProvidedSparkContext) {
throw new IllegalStateException(
"Usage of provided Spark context is disabled in SparkPipelineOptions.");
} else if (!options.getSparkMaster().equals(reusableSparkMaster)) {
throw new IllegalStateException(
String.format(
"Cannot reuse spark context "
+ "with different spark master URL. Existing: %s, requested: %s.",
sparkMaster, options.getSparkMaster()));
reusableSparkMaster, options.getSparkMaster()));
}
return sparkContext;
return jsc;
} else {
return createSparkContext(contextOptions);
JavaSparkContext jsc = createSparkContext(contextOptions);
clearProvidedSparkContext(); // any provided context can't be valid anymore
return jsc;
}
}

public static synchronized void stopSparkContext(JavaSparkContext context) {
if (!Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !usesProvidedSparkContext) {
if (!Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !hasProvidedSparkContext) {
context.stop();
}
}

private static JavaSparkContext createSparkContext(SparkContextOptions contextOptions) {
if (usesProvidedSparkContext) {
LOG.info("Using a provided Spark Context");
JavaSparkContext jsc = contextOptions.getProvidedSparkContext();
if (jsc == null || jsc.sc().isStopped()) {
LOG.error("The provided Spark context " + jsc + " was not created or was stopped");
throw new RuntimeException("The provided Spark context was not created or was stopped");
}
return jsc;
} else {
LOG.info("Creating a brand new Spark Context.");
SparkConf conf = new SparkConf();
if (!conf.contains("spark.master")) {
// set master if not set.
conf.setMaster(contextOptions.getSparkMaster());
}

if (contextOptions.getFilesToStage() != null && !contextOptions.getFilesToStage().isEmpty()) {
conf.setJars(contextOptions.getFilesToStage().toArray(new String[0]));
}
private static JavaSparkContext createSparkContext(SparkPipelineOptions options) {
LOG.info("Creating a brand new Spark Context.");
SparkConf conf = new SparkConf();
if (!conf.contains("spark.master")) {
// set master if not set.
conf.setMaster(options.getSparkMaster());
}

conf.setAppName(contextOptions.getAppName());
// register immutable collections serializers because the SDK uses them.
conf.set("spark.kryo.registrator", SparkRunnerKryoRegistrator.class.getName());
return new JavaSparkContext(conf);
if (options.getFilesToStage() != null && !options.getFilesToStage().isEmpty()) {
conf.setJars(options.getFilesToStage().toArray(new String[0]));
}

conf.setAppName(options.getAppName());
// register immutable collections serializers because the SDK uses them.
conf.set("spark.kryo.registrator", SparkRunnerKryoRegistrator.class.getName());
return new JavaSparkContext(conf);
}
}
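The deprecation note above points tests at SparkContextRule, which is used in CacheTest below but whose implementation is not part of this section. A rough sketch of what such a rule could look like, built only on JUnit 4's ExternalResource and the factory methods added in this file; the actual rule in the PR may differ in names and options handling:

import org.apache.beam.runners.spark.SparkContextOptions;
import org.apache.beam.runners.spark.TestSparkRunner;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.rules.ExternalResource;

public class SparkContextRuleSketch extends ExternalResource {

  private JavaSparkContext sparkContext;

  @Override
  protected void before() {
    // Create a local context and make it available through SparkContextFactory.
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("beam-spark-test");
    sparkContext = new JavaSparkContext(conf);
    SparkContextFactory.setProvidedSparkContext(sparkContext);
  }

  @Override
  protected void after() {
    // Unregister first so no pipeline can pick up a stopped context.
    SparkContextFactory.clearProvidedSparkContext();
    sparkContext.stop();
    sparkContext = null;
  }

  public JavaSparkContext getSparkContext() {
    return sparkContext;
  }

  public SparkContextOptions createPipelineOptions() {
    SparkContextOptions options = PipelineOptionsFactory.as(SparkContextOptions.class);
    options.setRunner(TestSparkRunner.class);
    options.setUsesProvidedSparkContext(true);
    return options;
  }
}

Used as a @ClassRule, as in CacheTest below, a single context is created before the first test, shared by every test in the class, and cleaned up once all of them have run.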
@@ -25,11 +25,9 @@
import java.util.List;
import org.apache.beam.runners.spark.translation.Dataset;
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Create.Values;
@@ -39,7 +37,7 @@
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.ClassRule;
import org.junit.Test;

/** Tests of {@link Dataset#cache(String, Coder)} scenarios. */
@@ -48,13 +46,15 @@
})
public class CacheTest {

@ClassRule public static SparkContextRule contextRule = new SparkContextRule();

/**
* Test checks how the cache candidates map is populated by the runner when evaluating the
* pipeline.
*/
@Test
public void cacheCandidatesUpdaterTest() {
SparkPipelineOptions options = createOptions();
SparkPipelineOptions options = contextRule.createPipelineOptions();
Pipeline pipeline = Pipeline.create(options);
PCollection<String> pCollection = pipeline.apply(Create.of("foo", "bar"));

@@ -80,8 +80,8 @@ public void processElement(ProcessContext processContext) {
})
.withSideInputs(view));

JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
EvaluationContext ctxt = new EvaluationContext(jsc, pipeline, options);
EvaluationContext ctxt =
new EvaluationContext(contextRule.getSparkContext(), pipeline, options);
SparkRunner.CacheVisitor cacheVisitor =
new SparkRunner.CacheVisitor(new TransformTranslator.Translator(), ctxt);
pipeline.traverseTopologically(cacheVisitor);
@@ -91,15 +91,15 @@ public void processElement(ProcessContext processContext) {

@Test
public void shouldCacheTest() {
SparkPipelineOptions options = createOptions();
SparkPipelineOptions options = contextRule.createPipelineOptions();
options.setCacheDisabled(true);
Pipeline pipeline = Pipeline.create(options);

Values<String> valuesTransform = Create.of("foo", "bar");
PCollection pCollection = mock(PCollection.class);

JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
EvaluationContext ctxt = new EvaluationContext(jsc, pipeline, options);
EvaluationContext ctxt =
new EvaluationContext(contextRule.getSparkContext(), pipeline, options);
ctxt.getCacheCandidates().put(pCollection, 2L);

assertFalse(ctxt.shouldCache(valuesTransform, pCollection));
@@ -110,11 +110,4 @@ public void shouldCacheTest() {
GroupByKey<String, String> gbkTransform = GroupByKey.create();
assertFalse(ctxt.shouldCache(gbkTransform, pCollection));
}

private SparkPipelineOptions createOptions() {
SparkPipelineOptions options =
PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
options.setRunner(TestSparkRunner.class);
return options;
}
}