From 84c2513a2cb7a50ba01cdc0d49e2ad7f9583799b Mon Sep 17 00:00:00 2001
From: Russell Spitzer
Date: Fri, 3 Aug 2018 11:04:00 -0500
Subject: [PATCH 1/6] [SPARK-25003][PYSPARK] Use SessionExtensions in PySpark

Previously PySpark used the private SparkSession constructor when building
that object. This resulted in a SparkSession that never checked the
spark.sql.extensions parameter for additional session extensions. To fix this
we instead go through the SparkSession.builder() path, as SparkR does; this
loads the extensions and allows their use in PySpark.
---
 python/pyspark/sql/session.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f1ad6b1212ed9..16fe65a88f648 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -218,7 +218,9 @@ def __init__(self, sparkContext, jsparkSession=None):
                     .sparkContext().isStopped():
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
             else:
-                jsparkSession = self._jvm.SparkSession(self._jsc.sc())
+                jsparkSession = self._jvm.SparkSession.builder() \
+                    .sparkContext(self._jsc.sc()) \
+                    .getOrCreate()
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)

From def4f3e91b38e7d63c03fc7206489fd01220e878 Mon Sep 17 00:00:00 2001
From: Russell Spitzer
Date: Thu, 16 Aug 2018 16:12:41 -0500
Subject: [PATCH 2/6] [SPARK-25003][PYSPARK] Add Tests for spark.sql.extensions in PySpark

Adds a test which sets spark.sql.extensions to a custom extension class,
mirroring what SparkSessionExtensionSuite does in Scala.
---
 python/pyspark/sql/tests.py | 42 +++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 565654e7f03bb..44dd299a9a5b7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3563,6 +3563,48 @@ def test_query_execution_listener_on_collect_with_arrow(self):
             "The callback from the query execution listener should be called after 'toPandas'")


+class SparkExtensionsTest(unittest.TestCase, SQLTestUtils):
+    # These tests are separate because it uses 'spark.sql.extensions' which is
+    # static and immutable. This can't be set or unset, for example, via `spark.conf`.
+
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "SparkSessionExtensionSuite.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
+                "available. Will skip the related tests.")
+
+        # Note that 'spark.sql.extensions' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.extensions",
+                "org.apache.spark.sql.MyExtensions") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def test_use_custom_class_for_extensions(self):
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().planner().strategies().contains(
+                self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
+            "MySparkStrategy not found in active planner strategies")
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
+                self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
+            "MyRule not found in extended resolution rules")
+
+
 class SparkSessionTests(PySparkTestCase):
     # This test is separate because it's closely related with session's start and stop.


From 67d9772e1f470bba83e311cd10cfe586a7f11b92 Mon Sep 17 00:00:00 2001
From: Russell Spitzer
Date: Wed, 19 Sep 2018 08:16:44 -0500
Subject: [PATCH 3/6] SPARK-25003: Add helper methods to create new Extensions from Conf

Previously the only way to add extensions to a session was via the
getOrCreate method of the SparkSession builder. To facilitate session
creation from non-Scala frontends we add a new SparkSession constructor
which takes just the SparkContext and a SparkSessionExtensions instance.
We also add a new SparkSessionExtensions constructor which, given a
SparkConf, generates an extensions object with the user's configuration
already applied.
---
 python/pyspark/sql/session.py | 8 +++--
 .../org/apache/spark/sql/SparkSession.scala | 22 +++---
 .../spark/sql/SparkSessionExtensions.scala | 35 +++++++++++++++++++
 3 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 16fe65a88f648..b510e90848710 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -212,15 +212,17 @@ def __init__(self, sparkContext, jsparkSession=None):
         self._sc = sparkContext
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
+
         if jsparkSession is None:
             if self._jvm.SparkSession.getDefaultSession().isDefined() \
                     and not self._jvm.SparkSession.getDefaultSession().get() \
                     .sparkContext().isStopped():
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
             else:
-                jsparkSession = self._jvm.SparkSession.builder() \
-                    .sparkContext(self._jsc.sc()) \
-                    .getOrCreate()
+                extensions = self._sc._jvm.org.apache.spark.sql\
+                    .SparkSessionExtensions(self._jsc.getConf())
+                jsparkSession = self._jvm.SparkSession(self._jsc.sc(), extensions)
+
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 565042fcf762e..8b2ea7491e454 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -88,6 +88,10 @@ class SparkSession private(
     this(sc, None, None, new SparkSessionExtensions)
   }

+  private[sql] def this(sc: SparkContext, extensions: SparkSessionExtensions) {
+    this(sc, None, None, extensions)
+  }
+
   sparkContext.assertNotStopped()

   // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's.
@@ -935,23 +939,7 @@ object SparkSession extends Logging {
          // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
        }

-        // Initialize extensions if the user has defined a configurator class.
-        val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-        if (extensionConfOption.isDefined) {
-          val extensionConfClassName = extensionConfOption.get
-          try {
-            val extensionConfClass = Utils.classForName(extensionConfClassName)
-            val extensionConf = extensionConfClass.newInstance()
-              .asInstanceOf[SparkSessionExtensions => Unit]
-            extensionConf(extensions)
-          } catch {
-            // Ignore the error if we cannot find the class or when the class has the wrong type.
-            case e @ (_: ClassCastException |
-              _: ClassNotFoundException |
-              _: NoClassDefFoundError) =>
-              logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
-          }
-        }
+        SparkSessionExtensions.applyExtensionsFromConf(sparkContext.conf, extensions)

         session = new SparkSession(sparkContext, None, None, extensions)
         options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index f99c108161f94..0fc37187e69f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -19,10 +19,14 @@ package org.apache.spark.sql

 import scala.collection.mutable

+import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.StaticSQLConf
+import org.apache.spark.util.Utils

 /**
  * :: Experimental ::
@@ -66,6 +70,11 @@ class SparkSessionExtensions {
   type StrategyBuilder = SparkSession => Strategy
   type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface

+  private[sql] def this(conf: SparkConf) {
+    this()
+    SparkSessionExtensions.applyExtensionsFromConf(conf, this)
+  }
+
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

   /**
@@ -169,3 +178,29 @@ class SparkSessionExtensions {
     parserBuilders += builder
   }
 }
+
+object SparkSessionExtensions extends Logging {
+
+  /**
+   * Initialize extensions if the user has defined a configurator class in their SparkConf.
+   * This class will be applied to the extensions passed into this function.
+   */
+  private[sql] def applyExtensionsFromConf(conf: SparkConf, extensions: SparkSessionExtensions) {
+    val extensionConfOption = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
+    if (extensionConfOption.isDefined) {
+      val extensionConfClassName = extensionConfOption.get
+      try {
+        val extensionConfClass = Utils.classForName(extensionConfClassName)
+        val extensionConf = extensionConfClass.newInstance()
+          .asInstanceOf[SparkSessionExtensions => Unit]
+        extensionConf(extensions)
+      } catch {
+        // Ignore the error if we cannot find the class or when the class has the wrong type.
+        case e@(_: ClassCastException |
+          _: ClassNotFoundException |
+          _: NoClassDefFoundError) =>
+          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+      }
+    }
+  }
+}

From 4ddaff8fcd1bfedc221eda7d6f37e5e9f28bfb4e Mon Sep 17 00:00:00 2001
From: Russell Spitzer
Date: Mon, 1 Oct 2018 12:55:01 -0500
Subject: [PATCH 4/6] SPARK-25003: Address Reviewer Comments
---
 python/pyspark/sql/session.py | 5 +--
 .../org/apache/spark/sql/SparkSession.scala | 30 +++++++++++++---
 .../spark/sql/SparkSessionExtensions.scala | 35 -------------------
 3 files changed, 26 insertions(+), 44 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index b510e90848710..e704de63394ea 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -212,16 +212,13 @@ def __init__(self, sparkContext, jsparkSession=None):
         self._sc = sparkContext
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
-
         if jsparkSession is None:
             if self._jvm.SparkSession.getDefaultSession().isDefined() \
                     and not self._jvm.SparkSession.getDefaultSession().get() \
                     .sparkContext().isStopped():
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
             else:
-                extensions = self._sc._jvm.org.apache.spark.sql\
-                    .SparkSessionExtensions(self._jsc.getConf())
-                jsparkSession = self._jvm.SparkSession(self._jsc.sc(), extensions)
+                jsparkSession = self._jvm.SparkSession(self._jsc.sc())

         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 8b2ea7491e454..8a8c9b9b8e7f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -86,10 +86,7 @@ class SparkSession private(

   private[sql] def this(sc: SparkContext) {
     this(sc, None, None, new SparkSessionExtensions)
-  }
-
-  private[sql] def this(sc: SparkContext, extensions: SparkSessionExtensions) {
-    this(sc, None, None, extensions)
+    SparkSession.applyExtensionsFromConf(sc.getConf, this.extensions)
   }

   sparkContext.assertNotStopped()
@@ -939,7 +936,7 @@ object SparkSession extends Logging {
          // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
        }

-        SparkSessionExtensions.applyExtensionsFromConf(sparkContext.conf, extensions)
+        applyExtensionsFromConf(sparkContext.conf, extensions)

         session = new SparkSession(sparkContext, None, None, extensions)
         options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
@@ -1124,4 +1121,27 @@ object SparkSession extends Logging {
       SparkSession.clearDefaultSession()
     }
   }
+
+  /**
+   * Initialize extensions if the user has defined a configurator class in their SparkConf.
+   * This class will be applied to the extensions passed into this function.
+   */
+  private[sql] def applyExtensionsFromConf(conf: SparkConf, extensions: SparkSessionExtensions) {
+    val extensionConfOption = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
+    if (extensionConfOption.isDefined) {
+      val extensionConfClassName = extensionConfOption.get
+      try {
+        val extensionConfClass = Utils.classForName(extensionConfClassName)
+        val extensionConf = extensionConfClass.newInstance()
+          .asInstanceOf[SparkSessionExtensions => Unit]
+        extensionConf(extensions)
+      } catch {
+        // Ignore the error if we cannot find the class or when the class has the wrong type.
+        case e@(_: ClassCastException |
+          _: ClassNotFoundException |
+          _: NoClassDefFoundError) =>
+          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+      }
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 0fc37187e69f1..f99c108161f94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -19,14 +19,10 @@ package org.apache.spark.sql

 import scala.collection.mutable

-import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.StaticSQLConf
-import org.apache.spark.util.Utils

 /**
  * :: Experimental ::
@@ -70,11 +66,6 @@ class SparkSessionExtensions {
   type StrategyBuilder = SparkSession => Strategy
   type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface

-  private[sql] def this(conf: SparkConf) {
-    this()
-    SparkSessionExtensions.applyExtensionsFromConf(conf, this)
-  }
-
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

   /**
@@ -178,29 +169,3 @@ class SparkSessionExtensions {
     parserBuilders += builder
   }
 }
-
-object SparkSessionExtensions extends Logging {
-
-  /**
-   * Initialize extensions if the user has defined a configurator class in their SparkConf.
-   * This class will be applied to the extensions passed into this function.
-   */
-  private[sql] def applyExtensionsFromConf(conf: SparkConf, extensions: SparkSessionExtensions) {
-    val extensionConfOption = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-    if (extensionConfOption.isDefined) {
-      val extensionConfClassName = extensionConfOption.get
-      try {
-        val extensionConfClass = Utils.classForName(extensionConfClassName)
-        val extensionConf = extensionConfClass.newInstance()
-          .asInstanceOf[SparkSessionExtensions => Unit]
-        extensionConf(extensions)
-      } catch {
-        // Ignore the error if we cannot find the class or when the class has the wrong type.
-        case e@(_: ClassCastException |
-          _: ClassNotFoundException |
-          _: NoClassDefFoundError) =>
-          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
-      }
-    }
-  }
-}

From d9b2a55275b74c406d9f9c435bf1b53a6ef4b35a Mon Sep 17 00:00:00 2001
From: Russell Spitzer
Date: Tue, 16 Oct 2018 16:13:12 -0500
Subject: [PATCH 5/6] SPARK-25003: More Refactoring

Removes SparkConf from applyExtensions; it now accepts only an
Option[String] that may contain the class name of an extensions
configurator. Also removes errant whitespace.
---
 python/pyspark/sql/session.py | 1 -
 python/pyspark/sql/tests.py | 2 +-
 .../org/apache/spark/sql/SparkSession.scala | 25 +++++++++++++------
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index e704de63394ea..f1ad6b1212ed9 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -219,7 +219,6 @@ def __init__(self, sparkContext, jsparkSession=None):
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
             else:
                 jsparkSession = self._jvm.SparkSession(self._jsc.sc())
-
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 44dd299a9a5b7..3016ffbed63fd 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3563,7 +3563,7 @@ def test_query_execution_listener_on_collect_with_arrow(self):
             "The callback from the query execution listener should be called after 'toPandas'")


-class SparkExtensionsTest(unittest.TestCase, SQLTestUtils):
+class SparkExtensionsTest(unittest.TestCase):
     # These tests are separate because it uses 'spark.sql.extensions' which is
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 8a8c9b9b8e7f8..ba25bc458e55c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -84,9 +84,17 @@ class SparkSession private(
   // The call site where this SparkSession was constructed.
   private val creationSite: CallSite = Utils.getCallSite()

+  /**
+   * Constructor used in Pyspark. Contains explicit application of Spark Session Extensions
+   * which otherwise only occurs during getOrCreate. We cannot add this to the default constructor
+   * since that would cause every new session to reinvoke Spark Session Extensions on the currently
+   * running extensions.
+   */
   private[sql] def this(sc: SparkContext) {
     this(sc, None, None, new SparkSessionExtensions)
-    SparkSession.applyExtensionsFromConf(sc.getConf, this.extensions)
+    SparkSession.applyExtensions(
+      sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+      this.extensions)
   }

   sparkContext.assertNotStopped()
@@ -936,7 +944,9 @@ object SparkSession extends Logging {
          // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
        }

-        applyExtensionsFromConf(sparkContext.conf, extensions)
+        applyExtensions(
+          sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+          extensions)

         session = new SparkSession(sparkContext, None, None, extensions)
         options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
@@ -1123,13 +1133,12 @@ object SparkSession extends Logging {
   }

   /**
-   * Initialize extensions if the user has defined a configurator class in their SparkConf.
-   * This class will be applied to the extensions passed into this function.
+   * Initialize extensions for given extension classname. This class will be applied to the
+   * extensions passed into this function.
   */
-  private[sql] def applyExtensionsFromConf(conf: SparkConf, extensions: SparkSessionExtensions) {
-    val extensionConfOption = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-    if (extensionConfOption.isDefined) {
-      val extensionConfClassName = extensionConfOption.get
+  private def applyExtensions(extensionOption: Option[String], extensions: SparkSessionExtensions) {
+    if (extensionOption.isDefined) {
+      val extensionConfClassName = extensionOption.get
       try {
         val extensionConfClass = Utils.classForName(extensionConfClassName)
         val extensionConf = extensionConfClass.newInstance()

From 3629c78118eb77589ae1126e877f0fd2b57441ce Mon Sep 17 00:00:00 2001
From: Russell Spitzer
Date: Wed, 17 Oct 2018 10:41:25 -0500
Subject: [PATCH 6/6] SPARK-25003: Refactor ApplyExtensions Function

applyExtensions now returns the extensions instance it modifies, so the
PySpark constructor can pass the configured extensions directly to the
primary constructor.
---
 .../scala/org/apache/spark/sql/SparkSession.scala | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index ba25bc458e55c..1154f6c067176 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -91,10 +91,10 @@ class SparkSession private(
    * running extensions.
    */
   private[sql] def this(sc: SparkContext) {
-    this(sc, None, None, new SparkSessionExtensions)
-    SparkSession.applyExtensions(
-      sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
-      this.extensions)
+    this(sc, None, None,
+      SparkSession.applyExtensions(
+        sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+        new SparkSessionExtensions))
   }

   sparkContext.assertNotStopped()
@@ -1136,7 +1136,9 @@ object SparkSession extends Logging {
    * Initialize extensions for given extension classname. This class will be applied to the
    * extensions passed into this function.
    */
-  private def applyExtensions(extensionOption: Option[String], extensions: SparkSessionExtensions) {
+  private def applyExtensions(
+      extensionOption: Option[String],
+      extensions: SparkSessionExtensions): SparkSessionExtensions = {
     if (extensionOption.isDefined) {
       val extensionConfClassName = extensionOption.get
       try {
@@ -1152,5 +1154,6 @@ object SparkSession extends Logging {
           logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
       }
     }
+    extensions
   }
 }
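
Note on the extensions class used by the test: the configurator named by
spark.sql.extensions above, org.apache.spark.sql.MyExtensions, is defined in the
Scala SparkSessionExtensionSuite and is not part of these patches. The sketch
below shows what such a configurator could look like; the MyRule and
MySparkStrategy bodies here are hypothetical pass-through placeholders standing
in for the classes the PySpark test looks up, and only the injection API
(injectResolutionRule, injectPlannerStrategy) and the SparkSessionExtensions => Unit
shape required by applyExtensions are taken from the code above.

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical no-op rule: returns the plan unchanged.
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

// Hypothetical strategy: matches nothing, so planning falls through to the
// built-in strategies.
case class MySparkStrategy(spark: SparkSession) extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
}

// The configurator class named by spark.sql.extensions. It must be a
// SparkSessionExtensions => Unit, matching the asInstanceOf cast in
// applyExtensions above.
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectResolutionRule(MyRule)
    extensions.injectPlannerStrategy(MySparkStrategy)
  }
}

With a class like this on the driver classpath, the PySpark path added in these
patches only needs spark.sql.extensions set when the session is built (for
example via SparkSession.builder.config(...) as in the test); applyExtensions
then injects the rule and strategy before the Python-side SparkSession wrapper
is created.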