From 1a6cfcea2fd23a2a8b7cd0604507a8eb502962a6 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 16 Mar 2018 13:09:58 +0900
Subject: [PATCH 1/3] spark.conf.get(value, default=None) should produce None in PySpark

---
 python/pyspark/sql/conf.py    | 9 +++++----
 python/pyspark/sql/context.py | 4 ++--
 python/pyspark/sql/tests.py   | 4 ++++
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index d929834aeeaa5..b82224b6194ed 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -17,7 +17,7 @@
 
 import sys
 
-from pyspark import since
+from pyspark import since, _NoValue
 from pyspark.rdd import ignore_unicode_prefix
 
 
@@ -39,15 +39,16 @@ def set(self, key, value):
 
     @ignore_unicode_prefix
     @since(2.0)
-    def get(self, key, default=None):
+    def get(self, key, default=_NoValue):
         """Returns the value of Spark runtime configuration property for the given key,
         assuming it is set.
         """
         self._checkType(key, "key")
-        if default is None:
+        if default is _NoValue:
             return self._jconf.get(key)
         else:
-            self._checkType(default, "default")
+            if default is not None:
+                self._checkType(default, "default")
             return self._jconf.get(key, default)
 
     @ignore_unicode_prefix
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 6cb90399dd616..e1fb8bd45ef82 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -22,7 +22,7 @@
 if sys.version >= '3':
     basestring = unicode = str
 
-from pyspark import since
+from pyspark import since, _NoValue
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.sql.session import _monkey_patch_RDD, SparkSession
 from pyspark.sql.dataframe import DataFrame
@@ -124,7 +124,7 @@ def setConf(self, key, value):
 
     @ignore_unicode_prefix
     @since(1.3)
-    def getConf(self, key, defaultValue=None):
+    def getConf(self, key, defaultValue=_NoValue):
         """Returns the value of Spark SQL configuration property for the given key.
 
         If the key is not set and defaultValue is not None, return
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 480815d27333f..7cea8f78766fc 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2504,6 +2504,10 @@ def test_conf(self):
         spark.conf.unset("bogo")
         self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
 
+        self.assertRaisesRegexp(Exception, "hyukjin", lambda: spark.conf.get("hyukjin"))
+        self.assertEqual(spark.conf.get("hyukjin", None), None)
+        self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None)
+
     def test_current_database(self):
         spark = self.spark
         spark.catalog._reset()

From d8ee18fab1a6183dfffa6e070e852fda67e1d809 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 16 Mar 2018 15:47:18 +0900
Subject: [PATCH 2/3] Fix few words

---
 python/pyspark/sql/context.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index e1fb8bd45ef82..e9ec7ba866761 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -127,8 +127,8 @@ def setConf(self, key, value):
     def getConf(self, key, defaultValue=_NoValue):
         """Returns the value of Spark SQL configuration property for the given key.
 
-        If the key is not set and defaultValue is not None, return
-        defaultValue. If the key is not set and defaultValue is None, return
+        If the key is not set and defaultValue is set, return
+        defaultValue. If the key is not set and defaultValue is not set, return
         the system default value.
 
         >>> sqlContext.getConf("spark.sql.shuffle.partitions")

From 3a8de043162adc250aa734f747cece1bb161e20e Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sat, 17 Mar 2018 14:07:56 +0900
Subject: [PATCH 3/3] Adds one more test with some comments

---
 python/pyspark/sql/tests.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7cea8f78766fc..a0d547ad620e5 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2504,8 +2504,15 @@ def test_conf(self):
         spark.conf.unset("bogo")
         self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
 
-        self.assertRaisesRegexp(Exception, "hyukjin", lambda: spark.conf.get("hyukjin"))
         self.assertEqual(spark.conf.get("hyukjin", None), None)
+
+        # This returns 'STATIC' because it's the default value of
+        # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in
+        # `spark.conf.get` is unset.
+        self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC")
+
+        # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but
+        # `defaultValue` in `spark.conf.get` is set to None.
         self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None)
 
     def test_current_database(self):
         spark = self.spark
         spark.catalog._reset()
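
The three patches above hinge on one pattern: get/getConf now default to the module-level _NoValue sentinel rather than None, so the code can tell "the caller passed no default" apart from "the caller explicitly passed None". Below is a minimal, standalone sketch of that sentinel-default idea, not PySpark's actual implementation; the names _NoValueType, _store, and get are invented for illustration, and a plain dict stands in for the JVM-backed configuration:

    class _NoValueType(object):
        """Unique marker type; instances are never valid configuration values."""

    _NoValue = _NoValueType()

    _store = {"spark.sql.shuffle.partitions": "200"}  # stand-in for the JVM-side conf

    def get(key, default=_NoValue):
        if default is _NoValue:
            # No default supplied: fall back to the stored/system value
            # (raises KeyError here; PySpark surfaces an error from the JVM instead).
            return _store[key]
        # An explicit default, including None, is honoured when the key is unset.
        return _store.get(key, default)

    print(get("spark.sql.shuffle.partitions"))  # -> '200'
    print(get("unknown.key", None))             # -> None
    print(get("unknown.key", "fallback"))       # -> 'fallback'

With a plain default=None signature, get("unknown.key", None) would be indistinguishable from get("unknown.key"), which is the behaviour the first patch fixes.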