diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 927554198743e..f94b9c2115044 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -291,7 +291,7 @@ def __init__(
         self,
         sparkContext: SparkContext,
         jsparkSession: Optional[JavaObject] = None,
-        options: Optional[Dict[str, Any]] = None,
+        options: Optional[Dict[str, Any]] = {},
     ):
         from pyspark.sql.context import SQLContext
 
@@ -305,10 +305,7 @@ def __init__(
             ):
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
             else:
-                jsparkSession = self._jvm.SparkSession(self._jsc.sc())
-            if options is not None:
-                for key, value in options.items():
-                    jsparkSession.sharedState().conf().set(key, value)
+                jsparkSession = self._jvm.SparkSession(self._jsc.sc(), options)
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index eb23b68ccf498..06771fac896ba 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -289,18 +289,23 @@ def test_another_spark_session(self):
             if session2 is not None:
                 session2.stop()
 
-    def test_create_spark_context_first_and_copy_options_to_sharedState(self):
+    def test_create_spark_context_with_initial_session_options(self):
         sc = None
         session = None
         try:
             conf = SparkConf().set("key1", "value1")
             sc = SparkContext("local[4]", "SessionBuilderTests", conf=conf)
             session = (
-                SparkSession.builder.config("key2", "value2").enableHiveSupport().getOrCreate()
+                SparkSession.builder.config("spark.sql.codegen.comments", "true")
+                .enableHiveSupport()
+                .getOrCreate()
             )
             self.assertEqual(session._jsparkSession.sharedState().conf().get("key1"), "value1")
-            self.assertEqual(session._jsparkSession.sharedState().conf().get("key2"), "value2")
+            self.assertEqual(
+                session._jsparkSession.sharedState().conf().get("spark.sql.codegen.comments"),
+                "true",
+            )
             self.assertEqual(
                 session._jsparkSession.sharedState().conf().get("spark.sql.catalogImplementation"),
                 "hive",
             )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 63812b873ba8e..df110aa269e7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -97,13 +97,17 @@ class SparkSession private(
    * since that would cause every new session to reinvoke Spark Session Extensions on the currently
    * running extensions.
    */
-  private[sql] def this(sc: SparkContext) = {
+  private[sql] def this(
+      sc: SparkContext,
+      initialSessionOptions: java.util.HashMap[String, String]) = {
     this(sc, None, None, SparkSession.applyExtensions(
       sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
-      new SparkSessionExtensions), Map.empty)
+      new SparkSessionExtensions), initialSessionOptions.asScala.toMap)
   }
 
+  private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]())
+
   private[sql] val sessionUUID: String = UUID.randomUUID.toString
 
   sparkContext.assertNotStopped()
 
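
Note (not part of the patch): a minimal sketch of the user-visible behavior, mirroring the
updated test above. It creates the SparkContext first so the JVM SparkSession does not exist
yet, then sets an option on the builder; with this change, that option is handed to the new
Scala constructor as initialSessionOptions instead of being copied into the shared-state conf
after construction. The sketch drops enableHiveSupport() for brevity, and _jsparkSession /
sharedState() are internal APIs used here only to observe the effect, exactly as the test does.

    from pyspark import SparkConf, SparkContext
    from pyspark.sql import SparkSession

    # SparkContext exists before any SparkSession, so getOrCreate() must build
    # the JVM session itself and pass the builder options through.
    sc = SparkContext("local[4]", "SessionBuilderTests", conf=SparkConf())
    spark = SparkSession.builder.config("spark.sql.codegen.comments", "true").getOrCreate()

    # The option landed in the shared state via the new constructor argument.
    assert (
        spark._jsparkSession.sharedState().conf().get("spark.sql.codegen.comments") == "true"
    )
    spark.stop()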