From d01b249669067db40662b2d5b7794a9ec96ace1f Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Sat, 21 Nov 2015 15:20:02 +0800
Subject: [PATCH 1/8] add model save/load for RFormula

---
 .../apache/spark/ml/feature/RFormula.scala         | 191 ++++++++++++++++--
 1 file changed, 176 insertions(+), 15 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 5c43a41bee3b4..bb74085056174 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -20,12 +20,14 @@ package org.apache.spark.ml.feature
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
-import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
 import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types._
@@ -47,7 +49,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
  * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
  */
 @Experimental
-class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
+class RFormula(override val uid: String)
+  extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("rFormula"))
 
@@ -159,6 +162,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
   override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
 }
 
+@Since("1.6.0")
+object RFormula extends DefaultParamsReadable[RFormula] {
+
+  @Since("1.6.0")
+  override def load(path: String): RFormula = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
@@ -168,9 +178,9 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 @Experimental
 class RFormulaModel private[feature](
     override val uid: String,
-    resolvedFormula: ResolvedRFormula,
-    pipelineModel: PipelineModel)
-  extends Model[RFormulaModel] with RFormulaBase {
+    val resolvedFormula: ResolvedRFormula,
+    val pipelineModel: PipelineModel)
+  extends Model[RFormulaModel] with RFormulaBase with MLWritable {
 
   override def transform(dataset: DataFrame): DataFrame = {
     checkCanTransform(dataset.schema)
@@ -225,14 +235,71 @@ class RFormulaModel private[feature](
       !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
       "Label column already exists and is not of type DoubleType.")
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
+}
+
+@Since("1.6.0")
+object RFormulaModel extends MLReadable[RFormulaModel] {
+
+  @Since("1.6.0")
+  override def read: MLReader[RFormulaModel] = new RFormulaModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): RFormulaModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[RFormulaModel]] */
+  private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: resolvedFormula
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(instance.resolvedFormula))
+        .repartition(1).write.parquet(dataPath)
+      // Save pipeline model
+      val pmPath = new Path(path, "pipelineModel").toString
+      instance.pipelineModel.save(pmPath)
+    }
+  }
+
+  private class RFormulaModelReader extends MLReader[RFormulaModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[RFormulaModel].getName
+
+    override def load(path: String): RFormulaModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
+      val label = data.getString(0)
+      val terms = data.getAs[Seq[Seq[String]]](1)
+      val hasIntercept = data.getBoolean(2)
+      val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept)
+
+      val pmPath = new Path(path, "pipelineModel").toString
+      val pipelineModel = PipelineModel.load(pmPath)
+
+      val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**
  * Utility transformer for removing temporary columns from a DataFrame.
 * TODO(ekl) make this a public transformer
 */
-private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
-  override val uid = Identifiable.randomUID("columnPruner")
+private[RFomula] class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
+  extends Transformer with MLWritable {
+
+  private[RFormula] def this(columnsToPrune: Set[String]) =
+    this(Identifiable.randomUID("columnPruner"), columnsToPrune)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
@@ -244,6 +311,51 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }
 
   override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
+}
+
+@Since("1.6.0")
+object ColumnPruner extends MLReadable[ColumnPruner] {
+
+  @Since("1.6.0")
+  override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
+
+  @Since("1.6.0")
+  override def load(path: String): ColumnPruner = super.load(path)
+
+  /** [[MLWriter]] instance for [[ColumnPruner]] */
+  private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter {
+
+    private case class Data(columnsToPrune: Seq[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: columnsToPrune
+      val data = Data(instance.columnsToPrune.toSeq)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+  private class ColumnPrunerReader extends MLReader[ColumnPruner] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[ColumnPruner].getName
+
+    override def load(path: String): ColumnPruner = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
+      val columnsToPrune = data.getAs[Seq[String]](0).toSet
+      val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
+
+      DefaultParamsReader.getAndSetParams(pruner, metadata)
+      pruner
+    }
+  }
 }
 
 /**
@@ -256,12 +368,14 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
  *                          map. When a key prefixes a name, the matching prefix will be replaced
 *                          by the value in the map.
 */
-private class VectorAttributeRewriter(
-    vectorCol: String,
-    prefixesToRewrite: Map[String, String])
-  extends Transformer {
+private[ml] class VectorAttributeRewriter(
+    override val uid: String,
+    val vectorCol: String,
+    val prefixesToRewrite: Map[String, String])
+  extends Transformer with MLWritable {
 
-  override val uid = Identifiable.randomUID("vectorAttrRewriter")
+  def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
+    this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val metadata = {
@@ -294,4 +408,51 @@ private class VectorAttributeRewriter(
   }
 
   override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
+}
+
+@Since("1.6.0")
+object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
+
+  @Since("1.6.0")
+  override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
+
+  @Since("1.6.0")
+  override def load(path: String): VectorAttributeRewriter = super.load(path)
+
+  /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */
+  private[VectorAttributeRewriter]
+  class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter {
+
+    private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: vectorCol, prefixesToRewrite
+      val data = Data(instance.vectorCol, instance.prefixesToRewrite)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+  private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[VectorAttributeRewriter].getName
+
+    override def load(path: String): VectorAttributeRewriter = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
+      val vectorCol = data.getString(0)
+      val prefixesToRewrite = data.getAs[Map[String, String]](1)
+      val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
+
+      DefaultParamsReader.getAndSetParams(rewriter, metadata)
+      rewriter
+    }
+  }
 }

From aa2d9aafadd162c9627214c62683f3f1d01efd79 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Sat, 21 Nov 2015 16:00:14 +0800
Subject: [PATCH 2/8] add test

---
 .../org/apache/spark/ml/feature/RFormula.scala     | 10 +++++-----
 .../apache/spark/ml/feature/RFormulaSuite.scala    | 17 ++++++++++++++++-
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index bb74085056174..9ce8bac3d85cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -295,10 +295,10 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
 * Utility transformer for removing temporary columns from a DataFrame.
 * TODO(ekl) make this a public transformer
 */
-private[RFomula] class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
+private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
   extends Transformer with MLWritable {
 
-  private[RFormula] def this(columnsToPrune: Set[String]) =
+  def this(columnsToPrune: Set[String]) =
     this(Identifiable.randomUID("columnPruner"), columnsToPrune)
 
   override def transform(dataset: DataFrame): DataFrame = {
@@ -317,7 +317,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str
 }
 
 @Since("1.6.0")
-object ColumnPruner extends MLReadable[ColumnPruner] {
+private object ColumnPruner extends MLReadable[ColumnPruner] {
 
   @Since("1.6.0")
   override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
@@ -368,7 +368,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
  *                          map. When a key prefixes a name, the matching prefix will be replaced
 *                          by the value in the map.
 */
-private[ml] class VectorAttributeRewriter(
+private class VectorAttributeRewriter(
     override val uid: String,
     val vectorCol: String,
     val prefixesToRewrite: Map[String, String])
   extends Transformer with MLWritable {
@@ -414,7 +414,7 @@ private class VectorAttributeRewriter(
 }
 
 @Since("1.6.0")
-object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
+private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
 
   @Since("1.6.0")
   override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index dc20a5ec2152d..772d6c622bd00 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new RFormula())
   }
@@ -214,4 +215,18 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
       new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
     assert(attrs === expectedAttrs)
   }
+
+  test("read/write") {
+    def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = {
+      assert(model.resolvedFormula === model2.resolvedFormula)
+      assert(model.pipelineModel === model2.pipelineModel)
+    }
+
+    val dataset = sqlContext.createDataFrame(
+      Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+    ).toDF("id", "a", "b")
+
+    val rFormula = new RFormula().setFormula("id ~ a:b")
+    testEstimatorAndModelReadWrite(rFormula, dataset, Map[String, Any](), checkModelData)
+  }
 }

From a4b72a0310372567509050845e77ffd517e13ce8 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Sat, 21 Nov 2015 17:55:25 +0800
Subject: [PATCH 3/8] split test

---
 .../spark/ml/feature/RFormulaSuite.scala | 31 ++++++++++++++++---
 1 file changed, 27 insertions(+), 4 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 772d6c622bd00..2953890f21767 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -216,10 +216,30 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(attrs === expectedAttrs)
   }
 
-  test("read/write") {
+  test("read/write: RFormula") {
+    val rFormula = new RFormula()
+      .setFormula("id ~ a:b")
+      .setFeaturesCol("myFeatures")
+      .setLabelCol("myLabels")
+
+    testDefaultReadWrite(rFormula)
+  }
+
+  test("read/write: RFormulaModel") {
     def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = {
-      assert(model.resolvedFormula === model2.resolvedFormula)
-      assert(model.pipelineModel === model2.pipelineModel)
+      assert(model.uid === model2.uid)
+
+      assert(model.resolvedFormula.label === model2.resolvedFormula.label)
+      assert(model.resolvedFormula.terms === model2.resolvedFormula.terms)
+      assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept)
+
+      assert(model.pipelineModel.uid === model.pipelineModel.uid)
+
+      model.pipelineModel.stages.zip(model2.pipelineModel.stages).foreach {
+        case (transformer1, transformer2) =>
+          assert(transformer1.uid === transformer2.uid)
+          assert(transformer1.params === transformer2.params)
+      }
     }
 
     val dataset = sqlContext.createDataFrame(
@@ -227,6 +247,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     ).toDF("id", "a", "b")
 
     val rFormula = new RFormula().setFormula("id ~ a:b")
-    testEstimatorAndModelReadWrite(rFormula, dataset, Map[String, Any](), checkModelData)
+
+    val model = rFormula.fit(dataset)
+    val newModel = testDefaultReadWrite(model)
+    checkModelData(model, newModel)
   }
 }

From 7b9abf6dbeb77ed41bff3d64c1a2f33008192271 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Tue, 24 Nov 2015 22:23:52 +0800
Subject: [PATCH 4/8] update for v1.7

---
 .../apache/spark/ml/feature/RFormula.scala | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 9ce8bac3d85cb..863b272a1d1ec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -162,10 +162,10 @@ class RFormula(override val uid: String)
   override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
 }
 
-@Since("1.6.0")
+@Since("1.7.0")
 object RFormula extends DefaultParamsReadable[RFormula] {
 
-  @Since("1.6.0")
+  @Since("1.7.0")
   override def load(path: String): RFormula = super.load(path)
 }
 
@@ -236,17 +236,17 @@ class RFormulaModel private[feature](
       "Label column already exists and is not of type DoubleType.")
   }
 
-  @Since("1.6.0")
+  @Since("1.7.0")
   override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
 }
 
-@Since("1.6.0")
+@Since("1.7.0")
 object RFormulaModel extends MLReadable[RFormulaModel] {
 
-  @Since("1.6.0")
+  @Since("1.7.0")
   override def read: MLReader[RFormulaModel] = new RFormulaModelReader
 
-  @Since("1.6.0")
+  @Since("1.7.0")
   override def load(path: String): RFormulaModel = super.load(path)
 
   /** [[MLWriter]] instance for [[RFormulaModel]] */
@@ -312,17 +312,17 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str
 
   override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
 
-  @Since("1.6.0")
+  @Since("1.7.0")
   override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
 }
-@Since("1.6.0") +@Since("1.7.0") private object ColumnPruner extends MLReadable[ColumnPruner] { - @Since("1.6.0") + @Since("1.7.0") override def read: MLReader[ColumnPruner] = new ColumnPrunerReader - @Since("1.6.0") + @Since("1.7.0") override def load(path: String): ColumnPruner = super.load(path) /** [[MLWriter]] instance for [[ColumnPruner]] */ @@ -409,17 +409,17 @@ private class VectorAttributeRewriter( override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra) - @Since("1.6.0") + @Since("1.7.0") override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this) } -@Since("1.6.0") +@Since("1.7.0") private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] { - @Since("1.6.0") + @Since("1.7.0") override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader - @Since("1.6.0") + @Since("1.7.0") override def load(path: String): VectorAttributeRewriter = super.load(path) /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */ From b18cea0eceb955f2f6f5c2d6af8269fc68052b56 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 7 Jan 2016 21:25:15 +0800 Subject: [PATCH 5/8] change spark version --- .../apache/spark/ml/feature/RFormula.scala | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 2c13c8592f790..3e396b18e3f22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.param.{Param, ParamMap} @@ -163,10 +163,10 @@ class RFormula(override val uid: String) override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)" } -@Since("1.7.0") +@Since("2.0.0") object RFormula extends DefaultParamsReadable[RFormula] { - @Since("1.7.0") + @Since("2.0.0") override def load(path: String): RFormula = super.load(path) } @@ -238,17 +238,17 @@ class RFormulaModel private[feature]( "Label column already exists and is not of type DoubleType.") } - @Since("1.7.0") + @Since("2.0.0") override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this) } -@Since("1.7.0") +@Since("2.0.0") object RFormulaModel extends MLReadable[RFormulaModel] { - @Since("1.7.0") + @Since("2.0.0") override def read: MLReader[RFormulaModel] = new RFormulaModelReader - @Since("1.7.0") + @Since("2.0.0") override def load(path: String): RFormulaModel = super.load(path) /** [[MLWriter]] instance for [[RFormulaModel]] */ @@ -315,17 +315,17 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) - @Since("1.7.0") + @Since("2.0.0") override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this) } -@Since("1.7.0") +@Since("2.0.0") private object ColumnPruner extends MLReadable[ColumnPruner] { - @Since("1.7.0") + @Since("2.0.0") override def read: MLReader[ColumnPruner] = new ColumnPrunerReader - @Since("1.7.0") + @Since("2.0.0") override def load(path: String): ColumnPruner = super.load(path) /** [[MLWriter]] 
@@ -413,17 +413,17 @@ private class VectorAttributeRewriter(
 
   override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)
 
-  @Since("1.7.0")
+  @Since("2.0.0")
   override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
 }
 
-@Since("1.7.0")
+@Since("2.0.0")
 private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
 
-  @Since("1.7.0")
+  @Since("2.0.0")
   override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
 
-  @Since("1.7.0")
+  @Since("2.0.0")
   override def load(path: String): VectorAttributeRewriter = super.load(path)
 
   /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */

From 6f66d8c1432b8fb2783a78121ff2406a7220b228 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Fri, 8 Jan 2016 11:07:59 +0800
Subject: [PATCH 6/8] fix crossvalidator

---
 .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 3eac616aeaf8e..96ecbd0ccbe6d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -27,7 +27,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.classification.OneVsRestParams
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.feature.RFormulaModel
+import org.apache.spark.ml.feature.{RFormula, RFormulaModel}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
@@ -220,10 +220,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
         // TODO: SPARK-11892: This case may require special handling.
         throw new UnsupportedOperationException("CrossValidator write will fail because it" +
           " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
-      case rform: RFormulaModel =>
-        // TODO: SPARK-11891: This case may require special handling.
-        throw new UnsupportedOperationException("CrossValidator write will fail because it" +
-          " cannot yet handle an estimator containing an RFormulaModel")
+      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
       case _: Params => Array()
     }
     val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)

From 3c8faab8a951fb0ccf6c4a8082f7c80f1c948534 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Fri, 8 Jan 2016 11:09:41 +0800
Subject: [PATCH 7/8] remove RFormula

---
 .../main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 96ecbd0ccbe6d..932cc3015b43d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -27,7 +27,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.classification.OneVsRestParams
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.feature.{RFormula, RFormulaModel}
+import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._

From 739eb876736badf04ac8b5c98bdb0791d4c52a59 Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Wed, 16 Mar 2016 21:04:31 -0700
Subject: [PATCH 8/8] fix nit

---
 .../apache/spark/ml/feature/RFormula.scala   | 20 +++++++------------
 .../spark/ml/feature/RFormulaSuite.scala     |  2 +-
 2 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 114e9e3a40d5d..e7ca7ada74c8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -22,12 +22,12 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.annotation.{Since, Experimental}
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
 import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
 import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types._
@@ -199,8 +199,8 @@ object RFormula extends DefaultParamsReadable[RFormula] {
 @Experimental
 class RFormulaModel private[feature](
     override val uid: String,
-    val resolvedFormula: ResolvedRFormula,
-    val pipelineModel: PipelineModel)
+    private[ml] val resolvedFormula: ResolvedRFormula,
+    private[ml] val pipelineModel: PipelineModel)
   extends Model[RFormulaModel] with RFormulaBase with MLWritable {
 
   override def transform(dataset: DataFrame): DataFrame = {
@@ -333,17 +333,13 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str
 
   override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
 
-  @Since("2.0.0")
   override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
 }
 
-@Since("2.0.0")
 private object ColumnPruner extends MLReadable[ColumnPruner] {
 
-  @Since("2.0.0")
   override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
 
-  @Since("2.0.0")
   override def load(path: String): ColumnPruner = super.load(path)
 
   /** [[MLWriter]] instance for [[ColumnPruner]] */
@@ -360,6 +356,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
+
   private class ColumnPrunerReader extends MLReader[ColumnPruner] {
 
     /** Checked against metadata when loading model */
@@ -430,17 +427,13 @@ private class VectorAttributeRewriter(
 
   override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)
 
-  @Since("2.0.0")
   override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
 }
 
-@Since("2.0.0")
 private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
 
-  @Since("2.0.0")
   override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
 
-  @Since("2.0.0")
   override def load(path: String): VectorAttributeRewriter = super.load(path)
 
   /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */
@@ -458,6 +451,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
+
   private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {
 
     /** Checked against metadata when loading model */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 76d3286c1369f..e1b269b5b681f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -271,7 +271,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       assert(model.resolvedFormula.terms === model2.resolvedFormula.terms)
       assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept)
 
-      assert(model.pipelineModel.uid === model.pipelineModel.uid)
+      assert(model.pipelineModel.uid === model2.pipelineModel.uid)
 
       model.pipelineModel.stages.zip(model2.pipelineModel.stages).foreach {
         case (transformer1, transformer2) =>
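
The series above makes RFormula persistable via DefaultParamsWritable and gives RFormulaModel a dedicated writer that stores the resolved formula as Parquet under data/ and the fitted PipelineModel under pipelineModel/. The sketch below is a minimal round trip distilled from the RFormulaSuite test data in patch 2; it is not part of the patches. The /tmp/rformula-demo paths are hypothetical placeholders, and it assumes a 1.6/2.0-era SQLContext named sqlContext is in scope, as in the test suite.

// Hypothetical usage sketch (not part of the patches above). Assumptions:
// a SQLContext named `sqlContext` is in scope, and the directories under
// /tmp/rformula-demo are writable placeholder paths.
import org.apache.spark.ml.feature.{RFormula, RFormulaModel}

val dataset = sqlContext.createDataFrame(
  Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b")

// The estimator persists through DefaultParamsWritable (patch 1)...
val rFormula = new RFormula().setFormula("id ~ a:b")
rFormula.save("/tmp/rformula-demo/estimator")
val sameEstimator = RFormula.load("/tmp/rformula-demo/estimator")

// ...and the fitted model through RFormulaModelWriter, which writes the
// resolved formula as Parquet under data/ and the internal PipelineModel
// under pipelineModel/.
val model = rFormula.fit(dataset)
model.save("/tmp/rformula-demo/model")

// Reload and check that the round trip preserves the model identity,
// mirroring the uid assertion in the "read/write: RFormulaModel" test.
val loaded = RFormulaModel.load("/tmp/rformula-demo/model")
assert(loaded.uid === model.uid)
loaded.transform(dataset).show()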