7 changes: 2 additions & 5 deletions python/pyspark/sql/session.py
@@ -291,7 +291,7 @@ def __init__(
         self,
         sparkContext: SparkContext,
         jsparkSession: Optional[JavaObject] = None,
-        options: Optional[Dict[str, Any]] = None,
+        options: Optional[Dict[str, Any]] = {},
     ):
         from pyspark.sql.context import SQLContext

@@ -305,10 +305,7 @@ def __init__(
             ):
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
             else:
-                jsparkSession = self._jvm.SparkSession(self._jsc.sc())
-                if options is not None:
-                    for key, value in options.items():
-                        jsparkSession.sharedState().conf().set(key, value)
+                jsparkSession = self._jvm.SparkSession(self._jsc.sc(), options)
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
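A note on the Python-side change: __init__ now forwards the options dict straight to the JVM constructor instead of copying entries into sharedState().conf() after construction; py4j (with auto-convert enabled in PySpark's gateway) should turn the dict into a java.util.HashMap, matching the new Scala overload below. A minimal sketch of the direct-construction path, assuming a local master, no pre-existing default JVM session, and using spark.sql.codegen.comments purely as an example key:

# Sketch only, not part of this PR: exercising SparkSession.__init__ with an
# explicit options dict, which now reaches the JVM through the new overload.
from pyspark import SparkContext
from pyspark.sql import SparkSession

sc = SparkContext("local[2]", "initial-options-sketch")

# The dict is handed to the JVM constructor as initialSessionOptions.
session = SparkSession(sc, options={"spark.sql.codegen.comments": "true"})

# As the updated test asserts, the option becomes visible on the shared state conf.
assert (
    session._jsparkSession.sharedState().conf().get("spark.sql.codegen.comments")
    == "true"
)

session.stop()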
11 changes: 8 additions & 3 deletions python/pyspark/sql/tests/test_session.py
@@ -289,18 +289,23 @@ def test_another_spark_session(self):
             if session2 is not None:
                 session2.stop()

-    def test_create_spark_context_first_and_copy_options_to_sharedState(self):
+    def test_create_spark_context_with_initial_session_options(self):
         sc = None
         session = None
         try:
             conf = SparkConf().set("key1", "value1")
             sc = SparkContext("local[4]", "SessionBuilderTests", conf=conf)
             session = (
-                SparkSession.builder.config("key2", "value2").enableHiveSupport().getOrCreate()
+                SparkSession.builder.config("spark.sql.codegen.comments", "true")
+                .enableHiveSupport()
+                .getOrCreate()
             )

             self.assertEqual(session._jsparkSession.sharedState().conf().get("key1"), "value1")
-            self.assertEqual(session._jsparkSession.sharedState().conf().get("key2"), "value2")
+            self.assertEqual(
+                session._jsparkSession.sharedState().conf().get("spark.sql.codegen.comments"),
+                "true",
+            )
             self.assertEqual(
                 session._jsparkSession.sharedState().conf().get("spark.sql.catalogImplementation"),
                 "hive",
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -97,13 +97,17 @@ class SparkSession private(
    * since that would cause every new session to reinvoke Spark Session Extensions on the currently
    * running extensions.
    */
-  private[sql] def this(sc: SparkContext) = {
+  private[sql] def this(
+      sc: SparkContext,
+      initialSessionOptions: java.util.HashMap[String, String]) = {
     this(sc, None, None,
       SparkSession.applyExtensions(
         sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
-        new SparkSessionExtensions), Map.empty)
+        new SparkSessionExtensions), initialSessionOptions.asScala.toMap)
   }

+  private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]())
+
   private[sql] val sessionUUID: String = UUID.randomUUID.toString

   sparkContext.assertNotStopped()
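Tying the three pieces together: the Scala side gains an overload that accepts the initial options as a java.util.HashMap, with the old single-argument constructor kept as a delegating overload for existing callers, and the PySpark builder's collected config(...) entries now travel through that overload, which matters when a SparkContext already exists and its conf can no longer be changed. A hedged end-to-end sketch of that builder path, closely following the updated test but dropping enableHiveSupport() to stay self-contained (local master and example keys only):

# Sketch only: builder options reaching shared state when the SparkContext is
# created first, mirroring test_create_spark_context_with_initial_session_options.
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

conf = SparkConf().set("key1", "value1")  # SparkContext-level conf
sc = SparkContext("local[4]", "builder-path-sketch", conf=conf)

session = (
    SparkSession.builder.config("spark.sql.codegen.comments", "true")
    .getOrCreate()
)

shared_conf = session._jsparkSession.sharedState().conf()
assert shared_conf.get("key1") == "value1"  # inherited from the SparkContext conf
assert shared_conf.get("spark.sql.codegen.comments") == "true"  # delivered via the new overload

session.stop()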