From f32350f9ed7c360e62d58441447dbd2fd05dd506 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 25 Apr 2018 09:58:52 +0000
Subject: [PATCH 1/4] Default Params in ML should be saved separately in Python API.

---
 python/pyspark/ml/tests.py | 38 ++++++++++++++++++++++++++++++++++++++
 python/pyspark/ml/util.py  | 33 +++++++++++++++++++++++++++++++--
 python/pyspark/util.py     | 21 +++++++++++++++++++++
 3 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2ec0be60e9fa9..b9866575c2db8 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1595,6 +1595,44 @@ def test_default_read_write(self):
         self.assertEqual(lr.uid, lr3.uid)
         self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
 
+    def test_default_read_write_default_params(self):
+        lr = LogisticRegression()
+        self.assertFalse(lr.isSet(lr.getParam("threshold")))
+
+        lr.setMaxIter(50)
+        lr.setThreshold(.75)
+
+        # `threshold` is set by user, default param `predictionCol` is not set by user.
+        self.assertTrue(lr.isSet(lr.getParam("threshold")))
+        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+        writer = DefaultParamsWriter(lr)
+        metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
+        self.assertTrue("defaultParamMap" in metadata)
+
+        reader = DefaultParamsReadable.read()
+        metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr, )
+        reader.getAndSetParams(lr, loadedMetadata)
+
+        self.assertTrue(lr.isSet(lr.getParam("threshold")))
+        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+        # manually create metadata without `defaultParamMap` section.
+        del metadata['defaultParamMap']
+        metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr, )
+        with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
+            reader.getAndSetParams(lr, loadedMetadata)
+
+        # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
+        metadata['sparkVersion'] = '2.3.0'
+        metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr, )
+        reader.getAndSetParams(lr, loadedMetadata)
+
 
 class LDATest(SparkSessionTestCase):
 
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index a486c6a3fdeb5..cdc0b51942425 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -30,6 +30,7 @@
 from pyspark import SparkContext, since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
+from pyspark.util import majorMinorVersion
 
 
 def _jvm():
@@ -396,6 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
         - sparkVersion
         - uid
         - paramMap
+        - defalutParamMap (since 2.4.0)
         - (optionally, extra metadata)
         :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
         :param paramMap: If given, this is saved in the "paramMap" field.
@@ -417,15 +419,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
         """
         uid = instance.uid
         cls = instance.__module__ + '.' + instance.__class__.__name__
-        params = instance.extractParamMap()
+
+        # User-supplied param values
+        params = instance._paramMap
         jsonParams = {}
         if paramMap is not None:
             jsonParams = paramMap
         else:
             for p in params:
                 jsonParams[p.name] = params[p]
+
+        # Default param values
+        jsonDefaultParams = {}
+        for p in instance._defaultParamMap:
+            jsonDefaultParams[p.name] = instance._defaultParamMap[p]
+
         basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
-                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams}
+                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
+                         "defaultParamMap": jsonDefaultParams}
         if extraMetadata is not None:
             basicMetadata.update(extraMetadata)
         return json.dumps(basicMetadata, separators=[',', ':'])
@@ -523,11 +534,29 @@ def getAndSetParams(instance, metadata):
         """
         Extract Params from metadata, and set them in the instance.
         """
+        # User-supplied param values
         for paramName in metadata['paramMap']:
             param = instance.getParam(paramName)
             paramValue = metadata['paramMap'][paramName]
             instance.set(param, paramValue)
 
+        # Default param values
+        majorAndMinorVersions = majorMinorVersion(metadata['sparkVersion'])
+        assert majorAndMinorVersions is not None, "Error loading metadata: Expected " + \
+            "Spark version string but found {}".format(metadata['sparkVersion'])
+
+        major = majorAndMinorVersions[0]
+        minor = majorAndMinorVersions[1]
+        # For metadata file prior to Spark 2.4, there is no default section.
+        if major > 2 or (major == 2 and minor >= 4):
+            assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
+                "`defaultParamMap` section not found"
+
+            for paramName in metadata['defaultParamMap']:
+                param = instance.getParam(paramName)
+                paramValue = metadata['defaultParamMap'][paramName]
+                instance._setDefault(**{param.name: paramValue})
+
     @staticmethod
     def loadParamsInstance(path, sc):
         """
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 49afc13640332..04df835bf6717 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 #
 
+import re
 import sys
 import inspect
 from py4j.protocol import Py4JJavaError
@@ -61,6 +62,26 @@ def _get_argspec(f):
     return argspec
 
 
+def majorMinorVersion(version):
+    """
+    Get major and minor version numbers for given Spark version string.
+
+    >>> version = "2.4.0"
+    >>> majorMinorVersion(version)
+    (2, 4)
+
+    >>> version = "abc"
+    >>> majorMinorVersion(version) is None
+    True
+
+    """
+    m = re.search('^(\d+)\.(\d+)(\..*)?$', version)
+    if m is None:
+        return None
+    else:
+        return (int(m.group(1)), int(m.group(2)))
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()

From 526fa4a96b61f4b5adb6606a92ed440879747a28 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 27 Apr 2018 13:55:11 +0000
Subject: [PATCH 2/4] Remove redundant call to getParam.

---
 python/pyspark/ml/util.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index cdc0b51942425..785de453572ba 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -553,9 +553,8 @@ def getAndSetParams(instance, metadata):
                 "`defaultParamMap` section not found"
 
             for paramName in metadata['defaultParamMap']:
-                param = instance.getParam(paramName)
                 paramValue = metadata['defaultParamMap'][paramName]
-                instance._setDefault(**{param.name: paramValue})
+                instance._setDefault(**{paramName: paramValue})
 
     @staticmethod
     def loadParamsInstance(path, sc):

From b47beab00876d5282ef60eb4e88d7314749a481b Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 9 May 2018 03:24:28 +0000
Subject: [PATCH 3/4] Sync with updated majorAndMinorVersions API.

---
 python/pyspark/ml/util.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 785de453572ba..4c9cc7c6bbd00 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -30,7 +30,7 @@
 from pyspark import SparkContext, since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
-from pyspark.util import majorMinorVersion
+from pyspark.util import VersionUtils
 
 
 def _jvm():
@@ -534,19 +534,17 @@ def getAndSetParams(instance, metadata):
         """
         Extract Params from metadata, and set them in the instance.
         """
-        # User-supplied param values
+        # Set user-supplied param values
         for paramName in metadata['paramMap']:
             param = instance.getParam(paramName)
             paramValue = metadata['paramMap'][paramName]
             instance.set(param, paramValue)
 
-        # Default param values
-        majorAndMinorVersions = majorMinorVersion(metadata['sparkVersion'])
-        assert majorAndMinorVersions is not None, "Error loading metadata: Expected " + \
-            "Spark version string but found {}".format(metadata['sparkVersion'])
-
+        # Set default param values
+        majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
         major = majorAndMinorVersions[0]
         minor = majorAndMinorVersions[1]
+
         # For metadata file prior to Spark 2.4, there is no default section.
         if major > 2 or (major == 2 and minor >= 4):
             assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
                 "`defaultParamMap` section not found"

From dc593754c62d2daf89331ea21d9250af9b9febfd Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 14 May 2018 23:57:52 +0000
Subject: [PATCH 4/4] Fix typo.

---
 python/pyspark/ml/util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 4c9cc7c6bbd00..9fa85664939b8 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -397,7 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
         - sparkVersion
         - uid
         - paramMap
-        - defalutParamMap (since 2.4.0)
+        - defaultParamMap (since 2.4.0)
         - (optionally, extra metadata)
         :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
         :param paramMap: If given, this is saved in the "paramMap" field.
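
For reference, a minimal sketch of the round trip these patches enable, mirroring the new unit test in PATCH 1/4. It assumes an active SparkContext bound to `sc` and a build with these changes applied, and it calls the internal DefaultParamsWriter/DefaultParamsReader helpers directly, so it is illustrative rather than a public-API example:

import json

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter

# `sc` is assumed to be an active SparkContext (e.g. spark.sparkContext).
lr = LogisticRegression()
lr.setThreshold(0.75)  # user-set param; predictionCol keeps its default

# The metadata JSON now carries both maps: user-set values under "paramMap"
# and defaults under "defaultParamMap" (written since Spark 2.4.0).
writer = DefaultParamsWriter(lr)
metadata = json.loads(writer._get_metadata_to_save(lr, sc))
assert "threshold" in metadata["paramMap"]
assert "predictionCol" in metadata["defaultParamMap"]

# On load, defaults are restored through _setDefault, so they remain
# "not set by the user" on the reconstructed instance.
lr2 = LogisticRegression()
DefaultParamsReader.getAndSetParams(lr2, metadata)
assert lr2.isSet(lr2.getParam("threshold"))
assert not lr2.isSet(lr2.getParam("predictionCol"))

The version check in getAndSetParams keeps pre-2.4 metadata, which has no defaultParamMap section, loadable.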