diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 093593132e56d..0dde0db9e3339 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1595,6 +1595,44 @@ def test_default_read_write(self):
         self.assertEqual(lr.uid, lr3.uid)
         self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
 
+    def test_default_read_write_default_params(self):
+        lr = LogisticRegression()
+        self.assertFalse(lr.isSet(lr.getParam("threshold")))
+
+        lr.setMaxIter(50)
+        lr.setThreshold(.75)
+
+        # `threshold` is set by user, default param `predictionCol` is not set by user.
+        self.assertTrue(lr.isSet(lr.getParam("threshold")))
+        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+        writer = DefaultParamsWriter(lr)
+        metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
+        self.assertTrue("defaultParamMap" in metadata)
+
+        reader = DefaultParamsReadable.read()
+        metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr, )
+        reader.getAndSetParams(lr, loadedMetadata)
+
+        self.assertTrue(lr.isSet(lr.getParam("threshold")))
+        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+        # manually create metadata without `defaultParamMap` section.
+        del metadata['defaultParamMap']
+        metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr, )
+        with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
+            reader.getAndSetParams(lr, loadedMetadata)
+
+        # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
+        metadata['sparkVersion'] = '2.3.0'
+        metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr, )
+        reader.getAndSetParams(lr, loadedMetadata)
+
 
 class LDATest(SparkSessionTestCase):
 
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index a486c6a3fdeb5..9fa85664939b8 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -30,6 +30,7 @@
 from pyspark import SparkContext, since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
+from pyspark.util import VersionUtils
 
 
 def _jvm():
@@ -396,6 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
         - sparkVersion
         - uid
         - paramMap
+        - defaultParamMap (since 2.4.0)
         - (optionally, extra metadata)
         :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
         :param paramMap: If given, this is saved in the "paramMap" field.
@@ -417,15 +419,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
         """
         uid = instance.uid
         cls = instance.__module__ + '.' + instance.__class__.__name__
-        params = instance.extractParamMap()
+
+        # User-supplied param values
+        params = instance._paramMap
         jsonParams = {}
         if paramMap is not None:
             jsonParams = paramMap
         else:
             for p in params:
                 jsonParams[p.name] = params[p]
+
+        # Default param values
+        jsonDefaultParams = {}
+        for p in instance._defaultParamMap:
+            jsonDefaultParams[p.name] = instance._defaultParamMap[p]
+
         basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
-                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams}
+                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
+                         "defaultParamMap": jsonDefaultParams}
         if extraMetadata is not None:
             basicMetadata.update(extraMetadata)
         return json.dumps(basicMetadata, separators=[',', ':'])
@@ -523,11 +534,26 @@ def getAndSetParams(instance, metadata):
         """
         Extract Params from metadata, and set them in the instance.
         """
+        # Set user-supplied param values
         for paramName in metadata['paramMap']:
             param = instance.getParam(paramName)
             paramValue = metadata['paramMap'][paramName]
             instance.set(param, paramValue)
 
+        # Set default param values
+        majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
+        major = majorAndMinorVersions[0]
+        minor = majorAndMinorVersions[1]
+
+        # For metadata file prior to Spark 2.4, there is no default section.
+        if major > 2 or (major == 2 and minor >= 4):
+            assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
+                "`defaultParamMap` section not found"
+
+            for paramName in metadata['defaultParamMap']:
+                paramValue = metadata['defaultParamMap'][paramName]
+                instance._setDefault(**{paramName: paramValue})
+
     @staticmethod
     def loadParamsInstance(path, sc):
         """
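
For reviewers who want to exercise the new code paths outside the test suite, here is a minimal sketch (not part of the patch; the local SparkSession setup and the choice of `LogisticRegression` are only illustrative) that drives the same internal helpers the test above uses. It shows the writer-side split between `paramMap` and the new `defaultParamMap` section, and the reader-side version gate that keeps pre-2.4 metadata loadable:

```python
import json

from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter
from pyspark.util import VersionUtils

spark = SparkSession.builder.master("local[1]").getOrCreate()
sc = spark.sparkContext

# Writer side: only user-set params land in `paramMap`; params still at their
# defaults are captured separately in the new `defaultParamMap` section.
lr = LogisticRegression()
lr.setMaxIter(50)
lr.setThreshold(0.75)
metadata = json.loads(DefaultParamsWriter(lr)._get_metadata_to_save(lr, sc))
assert sorted(metadata['paramMap']) == ['maxIter', 'threshold']
assert 'predictionCol' in metadata['defaultParamMap']
assert 'predictionCol' not in metadata['paramMap']

# Reader side: getAndSetParams() only requires the new section for metadata
# written by Spark 2.4+, as determined by VersionUtils.majorMinorVersion.
major, minor = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
assert (major, minor) >= (2, 4)

fresh = LogisticRegression()
DefaultParamsReader.getAndSetParams(fresh, metadata)
assert fresh.isSet(fresh.getParam('threshold'))          # restored as user-set
assert not fresh.isSet(fresh.getParam('predictionCol'))  # restored as a default
assert fresh.hasDefault(fresh.getParam('predictionCol'))

# Metadata written before 2.4 has no `defaultParamMap`; it must still load.
old_metadata = {k: v for k, v in metadata.items() if k != 'defaultParamMap'}
old_metadata['sparkVersion'] = '2.3.0'
DefaultParamsReader.getAndSetParams(LogisticRegression(), old_metadata)

spark.stop()
```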