From cf596346e2f5e333902e9e48a895040d3ec62409 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 17 Mar 2016 17:31:46 -0700 Subject: [PATCH 1/2] add pyspark rformula save/load --- python/pyspark/ml/feature.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 5025493c42c38..94ce87a88ddfd 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2360,7 +2360,7 @@ def explainedVariance(self): @inherit_doc -class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable): """ .. note:: Experimental @@ -2385,7 +2385,31 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): |0.0|0.0| a|[0.0,1.0]| 0.0| +---+---+---+---------+-----+ ... - >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show() + >>> model = rf.fit(df, {rf.formula: "y ~ . - s"}) + >>> model.transform(df).show() + +---+---+---+--------+-----+ + | y| x| s|features|label| + +---+---+---+--------+-----+ + |1.0|1.0| a| [1.0]| 1.0| + |0.0|2.0| b| [2.0]| 0.0| + |0.0|0.0| a| [0.0]| 0.0| + +---+---+---+--------+-----+ + ... + >>> rFormulaPath = temp_path + "/rFormula" + >>> rf.save(rFormulaPath) + >>> loadedRF = RFormula.load(rFormulaPath) + >>> loadedRF.getFormula() == rf.getFormula() + True + >>> loadedRF.getFeaturesCol() == rf.getFeaturesCol() + True + >>> loadedRF.getLabelCol() == rf.getLabelCol() + True + >>> modelPath = temp_path + "/rFormula-model" + >>> model.save(modelPath) + >>> loadedModel = RFormulaModel.load(modelPath) + >>> loadedModel.uid == model.uid + True + >>> loadedModel.transform(df).show() +---+---+---+--------+-----+ | y| x| s|features|label| +---+---+---+--------+-----+ @@ -2439,7 +2463,7 @@ def _create_model(self, java_model): return RFormulaModel(java_model) -class RFormulaModel(JavaModel): +class RFormulaModel(JavaModel, MLReadable, MLWritable): """ .. note:: Experimental From d31bb3f9cb616805ff48cce64d0d0f136cbd5639 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 18 Mar 2016 22:20:15 -0700 Subject: [PATCH 2/2] fix nit --- python/pyspark/ml/feature.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 94ce87a88ddfd..3182faac0de0f 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2376,7 +2376,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritabl ... (0.0, 0.0, "a") ... ], ["y", "x", "s"]) >>> rf = RFormula(formula="y ~ x + s") - >>> rf.fit(df).transform(df).show() + >>> model = rf.fit(df) + >>> model.transform(df).show() +---+---+---+---------+-----+ | y| x| s| features|label| +---+---+---+---------+-----+ @@ -2385,8 +2386,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritabl |0.0|0.0| a|[0.0,1.0]| 0.0| +---+---+---+---------+-----+ ... - >>> model = rf.fit(df, {rf.formula: "y ~ . - s"}) - >>> model.transform(df).show() + >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show() +---+---+---+--------+-----+ | y| x| s|features|label| +---+---+---+--------+-----+ @@ -2404,19 +2404,19 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritabl True >>> loadedRF.getLabelCol() == rf.getLabelCol() True - >>> modelPath = temp_path + "/rFormula-model" + >>> modelPath = temp_path + "/rFormulaModel" >>> model.save(modelPath) >>> loadedModel = RFormulaModel.load(modelPath) >>> loadedModel.uid == model.uid True >>> loadedModel.transform(df).show() - +---+---+---+--------+-----+ - | y| x| s|features|label| - +---+---+---+--------+-----+ - |1.0|1.0| a| [1.0]| 1.0| - |0.0|2.0| b| [2.0]| 0.0| - |0.0|0.0| a| [0.0]| 0.0| - +---+---+---+--------+-----+ + +---+---+---+---------+-----+ + | y| x| s| features|label| + +---+---+---+---------+-----+ + |1.0|1.0| a|[1.0,1.0]| 1.0| + |0.0|2.0| b|[2.0,0.0]| 0.0| + |0.0|0.0| a|[0.0,1.0]| 0.0| + +---+---+---+---------+-----+ ... .. versionadded:: 1.5.0