From 5bfe614e61b9b5d0d903a65e8c2eefe24bed4b3c Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 30 Aug 2018 17:50:23 -0700 Subject: [PATCH 01/14] [SPARK-25255][PYTHON]Add getActiveSession to SparkSession in PySpark --- python/pyspark/sql/session.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 079af8c05705d..0be9e6c697437 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -255,6 +255,16 @@ def newSession(self): """ return self.__class__(self._sc, self._jsparkSession.newSession()) + @since(2.4) + def getActiveSession(self): + """ + Returns the active SparkSession for the current thread, returned by the builder. + >>> s = spark.getActiveSession() + >>> spark._jsparkSession.getDefaultSession().get().equals(s.get()) + True + """ + return self._jsparkSession.getActiveSession() + @property @since(2.0) def sparkContext(self): From 9048a36cfeba3d0d02ed2acc9d9551f421b44d80 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 30 Aug 2018 18:31:10 -0700 Subject: [PATCH 02/14] fix python style error --- python/pyspark/sql/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 0be9e6c697437..c00fe4541d8fa 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -264,7 +264,7 @@ def getActiveSession(self): True """ return self._jsparkSession.getActiveSession() - + @property @since(2.0) def sparkContext(self): From 221ea01cd862b2c79ba7eb113b1fcdbd4425233a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 7 Sep 2018 18:02:49 -0700 Subject: [PATCH 03/14] address comments --- python/pyspark/sql/session.py | 13 ++++- python/pyspark/sql/tests.py | 101 ++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index c00fe4541d8fa..453842e88278a 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -234,6 +234,7 @@ def __init__(self, sparkContext, jsparkSession=None): or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self self._jvm.SparkSession.setDefaultSession(self._jsparkSession) + self._jvm.SparkSession.setActiveSession(self._jsparkSession) def _repr_html_(self): return """ @@ -260,10 +261,16 @@ def getActiveSession(self): """ Returns the active SparkSession for the current thread, returned by the builder. >>> s = spark.getActiveSession() - >>> spark._jsparkSession.getDefaultSession().get().equals(s.get()) - True + >>> l = [('Alice', 1)] + >>> rdd = s.sparkContext.parallelize(l) + >>> df = spark.createDataFrame(rdd, ['name', 'age']) + >>> df.collect() + [Row(name=u'Alice', age=1)] """ - return self._jsparkSession.getActiveSession() + if self._jsparkSession.getActiveSession().isDefined(): + return self.__class__(self._sc, self._jsparkSession.getActiveSession().get()) + else: + return None @property @since(2.0) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 85712df5f2ad1..5685d0460d3cd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3863,6 +3863,107 @@ def test_jvm_default_session_already_set(self): spark.stop() +class SparkSessionTests2(ReusedSQLTestCase): + + def test_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + activeSession = spark.getActiveSession() + df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name']) + self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')]) + finally: + spark.stop() + + def test_SparkSession(self): + spark = SparkSession.builder \ + .master("local") \ + .config("some-config", "v2") \ + .getOrCreate() + try: + self.assertEqual(spark.conf.get("some-config"), "v2") + self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2") + self.assertEqual(spark.version, spark.sparkContext.version) + spark.sql("CREATE DATABASE test_db") + spark.catalog.setCurrentDatabase("test_db") + self.assertEqual(spark.catalog.currentDatabase(), "test_db") + spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet") + self.assertEqual(spark.table("table1").columns, ['name', 'age']) + self.assertEqual(spark.range(3).count(), 3) + finally: + spark.stop() + + def test_global_default_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertEqual(SparkSession.builder.getOrCreate(), spark) + finally: + spark.stop() + + def test_default_and_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + activeSession = spark._jvm.SparkSession.getActiveSession() + defaultSession = spark._jvm.SparkSession.getDefaultSession() + try: + self.assertEqual(activeSession, defaultSession) + finally: + spark.stop() + + def test_config_option_propagated_to_existing_SparkSession(self): + session1 = SparkSession.builder \ + .master("local") \ + .config("spark-config1", "a") \ + .getOrCreate() + self.assertEqual(session1.conf.get("spark-config1"), "a") + session2 = SparkSession.builder \ + .config("spark-config1", "b") \ + .getOrCreate() + try: + self.assertEqual(session1, session2) + self.assertEqual(session1.conf.get("spark-config1"), "b") + finally: + session1.stop() + + def test_newSession(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + newSession = session.newSession() + try: + self.assertNotEqual(session, newSession) + finally: + session.stop() + newSession.stop() + + def test_create_new_session_if_old_session_stopped(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + session.stop() + newSession = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertNotEqual(session, newSession) + finally: + newSession.stop() + + def test_create_SparkContext_then_SparkSession(self): + sc = SparkContext('local', 'test') + session = SparkSession.builder \ + .config("key1", "value1") \ + .getOrCreate() + self.assertEqual(session.conf.get("key1"), "value1") + self.assertEqual(session.sparkContext, sc) + self.assertEqual(sc._conf.get("key1"), "value1") + session.stop() + + class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: From 6f890661b341d8e5bae5496e29cd3cc5c423cad7 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 10 Sep 2018 11:07:01 -0700 Subject: [PATCH 04/14] fix test failure --- python/pyspark/sql/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 453842e88278a..f1c1d56d3568f 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -264,8 +264,8 @@ def getActiveSession(self): >>> l = [('Alice', 1)] >>> rdd = s.sparkContext.parallelize(l) >>> df = spark.createDataFrame(rdd, ['name', 'age']) - >>> df.collect() - [Row(name=u'Alice', age=1)] + >>> df.select("age").collect() + [Row(age=1)] """ if self._jsparkSession.getActiveSession().isDefined(): return self.__class__(self._sc, self._jsparkSession.getActiveSession().get()) From c223dd2fd3e7fa255a08091f3c84eb7a08dfd671 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 11 Sep 2018 08:32:11 -0700 Subject: [PATCH 05/14] change the target version to 3.0 --- python/pyspark/sql/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index f1c1d56d3568f..1f92b88ef6780 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -256,7 +256,7 @@ def newSession(self): """ return self.__class__(self._sc, self._jsparkSession.newSession()) - @since(2.4) + @since(3.0) def getActiveSession(self): """ Returns the active SparkSession for the current thread, returned by the builder. From 091b1d52c4ed6072d5db9bde11a3bcdba2729dad Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 21 Sep 2018 10:51:09 -0700 Subject: [PATCH 06/14] address comments --- python/pyspark/sql/session.py | 1 + python/pyspark/sql/tests.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 1f92b88ef6780..a716f45ce638e 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -843,6 +843,7 @@ def stop(self): self._sc.stop() # We should clean the default session up. See SPARK-23228. self._jvm.SparkSession.clearDefaultSession() + self._jvm.SparkSession.clearActiveSession() SparkSession._instantiatedSession = None @since(2.0) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5685d0460d3cd..64b2f7babd6d5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3863,7 +3863,7 @@ def test_jvm_default_session_already_set(self): spark.stop() -class SparkSessionTests2(ReusedSQLTestCase): +class SparkSessionTests2(unittest.TestCase): def test_active_session(self): spark = SparkSession.builder \ @@ -3876,6 +3876,14 @@ def test_active_session(self): finally: spark.stop() + def test_get_active_session_when_no_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + spark.stop() + active = spark.getActiveSession() + self.assertEqual(active, None) + def test_SparkSession(self): spark = SparkSession.builder \ .master("local") \ @@ -3929,7 +3937,7 @@ def test_config_option_propagated_to_existing_SparkSession(self): finally: session1.stop() - def test_newSession(self): + def test_new_session(self): session = SparkSession.builder \ .master("local") \ .getOrCreate() From 1cda049e01651031f8192bd925ba31821049fe01 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 27 Sep 2018 09:47:55 -0700 Subject: [PATCH 07/14] change version to 2.5 --- python/pyspark/sql/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index a716f45ce638e..ae97ca7d80092 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -256,7 +256,7 @@ def newSession(self): """ return self.__class__(self._sc, self._jsparkSession.newSession()) - @since(3.0) + @since(2.5) def getActiveSession(self): """ Returns the active SparkSession for the current thread, returned by the builder. From 69b29e90b17c07f366c3a58733c3a55de7e93bdb Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 27 Sep 2018 11:08:04 -0700 Subject: [PATCH 08/14] remove test_create_SparkContext_then_SparkSession --- python/pyspark/sql/tests.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 64b2f7babd6d5..6615e9284e3fe 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3961,16 +3961,6 @@ def test_create_new_session_if_old_session_stopped(self): finally: newSession.stop() - def test_create_SparkContext_then_SparkSession(self): - sc = SparkContext('local', 'test') - session = SparkSession.builder \ - .config("key1", "value1") \ - .getOrCreate() - self.assertEqual(session.conf.get("key1"), "value1") - self.assertEqual(session.sparkContext, sc) - self.assertEqual(sc._conf.get("key1"), "value1") - session.stop() - class UDFInitializationTests(unittest.TestCase): def tearDown(self): From d8fef1c0fe6687b0d82f5071af98d500d229820c Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 1 Oct 2018 14:57:27 -0700 Subject: [PATCH 09/14] change getActiveSession to class method --- python/pyspark/sql/session.py | 16 +++++++++------- python/pyspark/sql/tests.py | 8 ++++++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index ae97ca7d80092..084f0dff4fd2b 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -192,6 +192,7 @@ def getOrCreate(self): """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances""" _instantiatedSession = None + _activeSession = None @ignore_unicode_prefix def __init__(self, sparkContext, jsparkSession=None): @@ -233,6 +234,7 @@ def __init__(self, sparkContext, jsparkSession=None): if SparkSession._instantiatedSession is None \ or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + SparkSession._activeSession = self self._jvm.SparkSession.setDefaultSession(self._jsparkSession) self._jvm.SparkSession.setActiveSession(self._jsparkSession) @@ -256,21 +258,20 @@ def newSession(self): """ return self.__class__(self._sc, self._jsparkSession.newSession()) + @classmethod @since(2.5) - def getActiveSession(self): + def getActiveSession(cls): """ Returns the active SparkSession for the current thread, returned by the builder. - >>> s = spark.getActiveSession() + >>> s = SparkSession.getActiveSession() >>> l = [('Alice', 1)] >>> rdd = s.sparkContext.parallelize(l) - >>> df = spark.createDataFrame(rdd, ['name', 'age']) + >>> df = s.createDataFrame(rdd, ['name', 'age']) + >>> df.show() >>> df.select("age").collect() [Row(age=1)] """ - if self._jsparkSession.getActiveSession().isDefined(): - return self.__class__(self._sc, self._jsparkSession.getActiveSession().get()) - else: - return None + return cls._activeSession @property @since(2.0) @@ -845,6 +846,7 @@ def stop(self): self._jvm.SparkSession.clearDefaultSession() self._jvm.SparkSession.clearActiveSession() SparkSession._instantiatedSession = None + SparkSession._activeSession = None @since(2.0) def __enter__(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6615e9284e3fe..a3d7331c4110c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3870,18 +3870,22 @@ def test_active_session(self): .master("local") \ .getOrCreate() try: - activeSession = spark.getActiveSession() + activeSession = SparkSession.getActiveSession() df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name']) self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')]) finally: spark.stop() def test_get_active_session_when_no_active_session(self): + active = SparkSession.getActiveSession() + self.assertEqual(active, None) spark = SparkSession.builder \ .master("local") \ .getOrCreate() + active = SparkSession.getActiveSession() + self.assertEqual(active, spark) spark.stop() - active = spark.getActiveSession() + active = SparkSession.getActiveSession() self.assertEqual(active, None) def test_SparkSession(self): From 59ad7a70c9694f7d66b3589dc86de5773db6844f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 1 Oct 2018 15:35:52 -0700 Subject: [PATCH 10/14] fix test failure --- python/pyspark/sql/session.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 084f0dff4fd2b..62e1cd4a772d2 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -267,7 +267,6 @@ def getActiveSession(cls): >>> l = [('Alice', 1)] >>> rdd = s.sparkContext.parallelize(l) >>> df = s.createDataFrame(rdd, ['name', 'age']) - >>> df.show() >>> df.select("age").collect() [Row(age=1)] """ From f2949f100c8395aeba98610c230a7d620d7ddbc5 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 8 Oct 2018 11:01:28 -0700 Subject: [PATCH 11/14] change getActiveSession to class method (2) --- python/pyspark/sql/functions.py | 17 ++++++++ python/pyspark/sql/session.py | 5 ++- .../scala/org/apache/spark/sql/MyTest.scala | 41 +++++++++++++++++++ 3 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MyTest.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 32d7f02f61883..28ab9b55949ac 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2713,6 +2713,23 @@ def from_csv(col, schema, options={}): return Column(jc) +@since(3.0) +def getActiveSession(): + """ + Returns the active SparkSession for the current thread + """ + from pyspark.sql import SparkSession + sc = SparkContext._active_spark_context + if sc is None: + sc = SparkContext() + + if sc._jvm.SparkSession.getActiveSession().isDefined(): + SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get()) + return SparkSession._activeSession + else: + return None + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 62e1cd4a772d2..b0bbb20e0cdab 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -259,7 +259,7 @@ def newSession(self): return self.__class__(self._sc, self._jsparkSession.newSession()) @classmethod - @since(2.5) + @since(3.0) def getActiveSession(cls): """ Returns the active SparkSession for the current thread, returned by the builder. @@ -270,7 +270,8 @@ def getActiveSession(cls): >>> df.select("age").collect() [Row(age=1)] """ - return cls._activeSession + from pyspark.sql import functions + return functions.getActiveSession() @property @since(2.0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MyTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/MyTest.scala new file mode 100644 index 0000000000000..0b273fe979a54 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MyTest.scala @@ -0,0 +1,41 @@ +package org.apache.spark.sql + + + +import java.util.Properties + +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} + +object MyTest { + def main(args: Array[String]) { + // $example on:init_session$ + val spark = SparkSession + .builder() + .appName("Spark SQL basic example") + .config("spark.some.config.option", "some-value") + .getOrCreate() + val sparkContext = spark.sparkContext + val url = "jdbc:h2:mem:testdb2" + var conn: java.sql.Connection = null + val url1 = "jdbc:h2:mem:testdb3" + var conn1: java.sql.Connection = null + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") + + + val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) + val arr1x2 = Array[Row](Row.apply("fred", 3)) + val schema2 = StructType( + StructField("name", StringType) :: + StructField("id", IntegerType) :: Nil) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + + df.write.mode(SaveMode.Ignore).jdbc(url1, "TEST.DROPTEST", properties) + spark.read.jdbc(url1, "TEST.DROPTEST", properties).count() + spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect() + spark.stop() + } +} From 7c6d2d57363c0cef2eeb80228cc2a2f14ec9b226 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 16 Oct 2018 11:55:57 -0700 Subject: [PATCH 12/14] address comments --- python/pyspark/sql/functions.py | 16 ++++++++------- python/pyspark/sql/session.py | 4 +++- python/pyspark/sql/tests.py | 36 +++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 28ab9b55949ac..838dabf6c31eb 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2714,20 +2714,22 @@ def from_csv(col, schema, options={}): @since(3.0) -def getActiveSession(): +def _getActiveSession(): """ Returns the active SparkSession for the current thread + This method is not intended for user to call directly. + It is only used for getActiveSession method in session.py """ from pyspark.sql import SparkSession sc = SparkContext._active_spark_context if sc is None: - sc = SparkContext() - - if sc._jvm.SparkSession.getActiveSession().isDefined(): - SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get()) - return SparkSession._activeSession - else: return None + else: + if sc._jvm.SparkSession.getActiveSession().isDefined(): + SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get()) + return SparkSession._activeSession + else: + return None # ---------------------------- User Defined Function ---------------------------------- diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b0bbb20e0cdab..e739c6ec82f4a 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -271,7 +271,7 @@ def getActiveSession(cls): [Row(age=1)] """ from pyspark.sql import functions - return functions.getActiveSession() + return functions._getActiveSession() @property @since(2.0) @@ -689,6 +689,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr ... Py4JJavaError: ... """ + SparkSession._activeSession = self + self._jvm.SparkSession.setActiveSession(self._jsparkSession) if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a3d7331c4110c..f6abf8c0bedf4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3965,6 +3965,42 @@ def test_create_new_session_if_old_session_stopped(self): finally: newSession.stop() + def test_active_session_with_None_and_not_None_context(self): + from pyspark.context import SparkContext + from pyspark.conf import SparkConf + sc = SparkContext._active_spark_context + self.assertEqual(sc, None) + activeSession = SparkSession.getActiveSession() + self.assertEqual(activeSession, None) + sparkConf = SparkConf() + sc = SparkContext.getOrCreate(sparkConf) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertFalse(activeSession.isDefined()) + session = SparkSession(sc) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertTrue(activeSession.isDefined()) + activeSession2 = SparkSession.getActiveSession() + self.assertNotEqual(activeSession2, None) + + +class SparkSessionTests3(ReusedSQLTestCase): + + def test_get_active_session_after_create_dataframe(self): + activeSession1 = SparkSession.getActiveSession() + session1 = self.spark + self.assertEqual(session1, activeSession1) + session2 = self.spark.newSession() + activeSession2 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession2) + self.assertNotEqual(session2, activeSession2) + session2.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession3 = SparkSession.getActiveSession() + self.assertEqual(session2, activeSession3) + session1.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession4 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession4) + session2.stop() + class UDFInitializationTests(unittest.TestCase): def tearDown(self): From fb474328d8a80d9763a2e3f44611f9a1dc2b47dd Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 18 Oct 2018 19:53:14 -0700 Subject: [PATCH 13/14] address comments --- python/pyspark/sql/functions.py | 19 --------- python/pyspark/sql/session.py | 12 +++++- python/pyspark/sql/tests.py | 68 +++++++++++++++++++-------------- 3 files changed, 50 insertions(+), 49 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 838dabf6c31eb..32d7f02f61883 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2713,25 +2713,6 @@ def from_csv(col, schema, options={}): return Column(jc) -@since(3.0) -def _getActiveSession(): - """ - Returns the active SparkSession for the current thread - This method is not intended for user to call directly. - It is only used for getActiveSession method in session.py - """ - from pyspark.sql import SparkSession - sc = SparkContext._active_spark_context - if sc is None: - return None - else: - if sc._jvm.SparkSession.getActiveSession().isDefined(): - SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get()) - return SparkSession._activeSession - else: - return None - - # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e739c6ec82f4a..6f4b32757314d 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -270,8 +270,16 @@ def getActiveSession(cls): >>> df.select("age").collect() [Row(age=1)] """ - from pyspark.sql import functions - return functions._getActiveSession() + from pyspark import SparkContext + sc = SparkContext._active_spark_context + if sc is None: + return None + else: + if sc._jvm.SparkSession.getActiveSession().isDefined(): + SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get()) + return SparkSession._activeSession + else: + return None @property @since(2.0) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f6abf8c0bedf4..2c30c3dfc2086 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3926,7 +3926,7 @@ def test_default_and_active_session(self): finally: spark.stop() - def test_config_option_propagated_to_existing_SparkSession(self): + def test_config_option_propagated_to_existing_session(self): session1 = SparkSession.builder \ .master("local") \ .config("spark-config1", "a") \ @@ -3968,38 +3968,50 @@ def test_create_new_session_if_old_session_stopped(self): def test_active_session_with_None_and_not_None_context(self): from pyspark.context import SparkContext from pyspark.conf import SparkConf - sc = SparkContext._active_spark_context - self.assertEqual(sc, None) - activeSession = SparkSession.getActiveSession() - self.assertEqual(activeSession, None) - sparkConf = SparkConf() - sc = SparkContext.getOrCreate(sparkConf) - activeSession = sc._jvm.SparkSession.getActiveSession() - self.assertFalse(activeSession.isDefined()) - session = SparkSession(sc) - activeSession = sc._jvm.SparkSession.getActiveSession() - self.assertTrue(activeSession.isDefined()) - activeSession2 = SparkSession.getActiveSession() - self.assertNotEqual(activeSession2, None) + sc = None + session = None + try: + sc = SparkContext._active_spark_context + self.assertEqual(sc, None) + activeSession = SparkSession.getActiveSession() + self.assertEqual(activeSession, None) + sparkConf = SparkConf() + sc = SparkContext.getOrCreate(sparkConf) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertFalse(activeSession.isDefined()) + session = SparkSession(sc) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertTrue(activeSession.isDefined()) + activeSession2 = SparkSession.getActiveSession() + self.assertNotEqual(activeSession2, None) + finally: + if session is not None: + session.stop() + if sc is not None: + sc.stop() class SparkSessionTests3(ReusedSQLTestCase): def test_get_active_session_after_create_dataframe(self): - activeSession1 = SparkSession.getActiveSession() - session1 = self.spark - self.assertEqual(session1, activeSession1) - session2 = self.spark.newSession() - activeSession2 = SparkSession.getActiveSession() - self.assertEqual(session1, activeSession2) - self.assertNotEqual(session2, activeSession2) - session2.createDataFrame([(1, 'Alice')], ['age', 'name']) - activeSession3 = SparkSession.getActiveSession() - self.assertEqual(session2, activeSession3) - session1.createDataFrame([(1, 'Alice')], ['age', 'name']) - activeSession4 = SparkSession.getActiveSession() - self.assertEqual(session1, activeSession4) - session2.stop() + session2 = None + try: + activeSession1 = SparkSession.getActiveSession() + session1 = self.spark + self.assertEqual(session1, activeSession1) + session2 = self.spark.newSession() + activeSession2 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession2) + self.assertNotEqual(session2, activeSession2) + session2.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession3 = SparkSession.getActiveSession() + self.assertEqual(session2, activeSession3) + session1.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession4 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession4) + finally: + if session2 is not None: + session2.stop() class UDFInitializationTests(unittest.TestCase): From 94e3db0c0c9873daaca688c2a63f01420882692e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 18 Oct 2018 19:56:29 -0700 Subject: [PATCH 14/14] Delete MyTest.scala --- .../scala/org/apache/spark/sql/MyTest.scala | 41 ------------------- 1 file changed, 41 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MyTest.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MyTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/MyTest.scala deleted file mode 100644 index 0b273fe979a54..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/MyTest.scala +++ /dev/null @@ -1,41 +0,0 @@ -package org.apache.spark.sql - - - -import java.util.Properties - -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} - -object MyTest { - def main(args: Array[String]) { - // $example on:init_session$ - val spark = SparkSession - .builder() - .appName("Spark SQL basic example") - .config("spark.some.config.option", "some-value") - .getOrCreate() - val sparkContext = spark.sparkContext - val url = "jdbc:h2:mem:testdb2" - var conn: java.sql.Connection = null - val url1 = "jdbc:h2:mem:testdb3" - var conn1: java.sql.Connection = null - val properties = new Properties() - properties.setProperty("user", "testUser") - properties.setProperty("password", "testPass") - properties.setProperty("rowId", "false") - - - val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) - val arr1x2 = Array[Row](Row.apply("fred", 3)) - val schema2 = StructType( - StructField("name", StringType) :: - StructField("id", IntegerType) :: Nil) - val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) - - df.write.mode(SaveMode.Ignore).jdbc(url1, "TEST.DROPTEST", properties) - spark.read.jdbc(url1, "TEST.DROPTEST", properties).count() - spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect() - spark.stop() - } -}