[SPARK-11891] Model export/import for RFormula and RFormulaModel #9884
Changes from all commits: d01b249, aa2d9aa, a4b72a0, 7b9abf6, 58f4791, b18cea0, 6f66d8c, 3c8faab, 0779c39, 739eb87
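This PR adds ML persistence to `RFormula` (via `DefaultParamsWritable`/`DefaultParamsReadable`), to `RFormulaModel` (via a dedicated writer/reader pair), and to the private helper transformers `ColumnPruner` and `VectorAttributeRewriter`. A minimal usage sketch of the new API; the paths and the `dataset` DataFrame are illustrative:

```scala
import org.apache.spark.ml.feature.{RFormula, RFormulaModel}

// Save/load the estimator (Params only, handled by DefaultParamsWritable).
val rFormula = new RFormula().setFormula("id ~ a:b")
rFormula.save("/tmp/rFormula")
val sameFormula = RFormula.load("/tmp/rFormula")

// Save/load the fitted model (Params + resolved formula + pipeline model).
val model = rFormula.fit(dataset) // `dataset` is an assumed input DataFrame
model.save("/tmp/rFormulaModel")
val sameModel = RFormulaModel.load("/tmp/rFormulaModel")
```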
File: mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
```diff
@@ -20,12 +20,14 @@ package org.apache.spark.ml.feature
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer

-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types._
```
```diff
@@ -68,7 +70,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
  * will be created from the specified response variable in the formula.
  */
 @Experimental
-class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
+class RFormula(override val uid: String)
+  extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {

   def this() = this(Identifiable.randomUID("rFormula"))
```
```diff
@@ -180,6 +183,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
   override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
 }

+@Since("2.0.0")
+object RFormula extends DefaultParamsReadable[RFormula] {
+
+  @Since("2.0.0")
+  override def load(path: String): RFormula = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
```
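Since `RFormula` carries no learned state, its persistence is entirely Params-based: `DefaultParamsWritable` supplies `write`/`save`, and the companion object's `DefaultParamsReadable` supplies `read`/`load`; the `load` override only narrows the return type. Both entry points below are equivalent (a sketch; the path is illustrative):

```scala
// Two equivalent ways to load, both provided by DefaultParamsReadable:
val viaLoad: RFormula = RFormula.load("/tmp/rFormula")
val viaRead: RFormula = RFormula.read.load("/tmp/rFormula")
```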
```diff
@@ -189,9 +199,9 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 @Experimental
 class RFormulaModel private[feature](
     override val uid: String,
-    resolvedFormula: ResolvedRFormula,
-    pipelineModel: PipelineModel)
-  extends Model[RFormulaModel] with RFormulaBase {
+    private[ml] val resolvedFormula: ResolvedRFormula,
+    private[ml] val pipelineModel: PipelineModel)
+  extends Model[RFormulaModel] with RFormulaBase with MLWritable {

   override def transform(dataset: DataFrame): DataFrame = {
     checkCanTransform(dataset.schema)
```
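`resolvedFormula` and `pipelineModel` are widened to `private[ml]` vals so the writer/reader below can serialize them and, per the review comment at the end of this thread, so that other `ml` code such as CrossValidator can reach `pipelineModel`. A tiny sketch of what the widened scope permits:

```scala
// Anywhere inside org.apache.spark.ml can now read the fitted internals,
// e.g. the writer below or (per the closing review comment) CrossValidator:
val resolved: ResolvedRFormula = model.resolvedFormula
val stages = model.pipelineModel.stages
```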
```diff
@@ -246,14 +256,71 @@ class RFormulaModel private[feature](
       !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
       "Label column already exists and is not of type DoubleType.")
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
+}
+
+@Since("2.0.0")
+object RFormulaModel extends MLReadable[RFormulaModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[RFormulaModel] = new RFormulaModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): RFormulaModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[RFormulaModel]] */
+  private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: resolvedFormula
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(instance.resolvedFormula))
+        .repartition(1).write.parquet(dataPath)
+      // Save pipeline model
+      val pmPath = new Path(path, "pipelineModel").toString
+      instance.pipelineModel.save(pmPath)
+    }
+  }
+
+  private class RFormulaModelReader extends MLReader[RFormulaModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[RFormulaModel].getName
+
+    override def load(path: String): RFormulaModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
+      val label = data.getString(0)
+      val terms = data.getAs[Seq[Seq[String]]](1)
+      val hasIntercept = data.getBoolean(2)
+      val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept)
+
+      val pmPath = new Path(path, "pipelineModel").toString
+      val pipelineModel = PipelineModel.load(pmPath)
+
+      val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+}

 /**
  * Utility transformer for removing temporary columns from a DataFrame.
  * TODO(ekl) make this a public transformer
  */
-private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
-  override val uid = Identifiable.randomUID("columnPruner")
+private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
+  extends Transformer with MLWritable {
+
+  def this(columnsToPrune: Set[String]) =
+    this(Identifiable.randomUID("columnPruner"), columnsToPrune)

   override def transform(dataset: DataFrame): DataFrame = {
     val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
```
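The writer splits the model across three locations under the save path: stage metadata/Params via `DefaultParamsWriter`, the `ResolvedRFormula` fields as a one-row Parquet file under `data/`, and the fitted pipeline as a nested `PipelineModel` save under `pipelineModel/`. A sketch of inspecting the persisted formula data directly, assuming `DefaultParamsWriter` keeps its JSON under `metadata/` as elsewhere in spark.ml:

```scala
// Expected layout after model.save("/tmp/m"):
//   /tmp/m/metadata/       - uid, class name, Params (JSON)
//   /tmp/m/data/           - one-row Parquet: label, terms, hasIntercept
//   /tmp/m/pipelineModel/  - nested PipelineModel.save output
sqlContext.read.parquet("/tmp/m/data")
  .select("label", "terms", "hasIntercept")
  .show()
```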
```diff
@@ -265,6 +332,48 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }

   override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
+
+  override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
+}
+
+private object ColumnPruner extends MLReadable[ColumnPruner] {
+
+  override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
+
+  override def load(path: String): ColumnPruner = super.load(path)
+
+  /** [[MLWriter]] instance for [[ColumnPruner]] */
+  private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter {
+
+    private case class Data(columnsToPrune: Seq[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: columnsToPrune
+      val data = Data(instance.columnsToPrune.toSeq)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class ColumnPrunerReader extends MLReader[ColumnPruner] {
```
> **Member:** need newline
```diff
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[ColumnPruner].getName
+
+    override def load(path: String): ColumnPruner = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
+      val columnsToPrune = data.getAs[Seq[String]](0).toSet
+      val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
+
+      DefaultParamsReader.getAndSetParams(pruner, metadata)
+      pruner
+    }
+  }
+}

 /**
```
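Note the small `Data` case class in the writer: `createDataFrame` infers a schema from a `Product`, and a `Set[String]` field is not supported by schema inference, so the set is stored as a `Seq[String]` and rebuilt with `.toSet` on load. A standalone sketch of the same round trip (path is illustrative):

```scala
// Round-tripping a Set[String] through Parquet via a Seq[String] column.
case class Data(columnsToPrune: Seq[String])

val written = sqlContext.createDataFrame(Seq(Data(Set("tmp1", "tmp2").toSeq)))
written.repartition(1).write.parquet("/tmp/prunerData")

val row = sqlContext.read.parquet("/tmp/prunerData").select("columnsToPrune").head()
val restored: Set[String] = row.getAs[Seq[String]](0).toSet
```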
```diff
@@ -278,11 +387,13 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
  * by the value in the map.
  */
 private class VectorAttributeRewriter(
-    vectorCol: String,
-    prefixesToRewrite: Map[String, String])
-  extends Transformer {
+    override val uid: String,
+    val vectorCol: String,
+    val prefixesToRewrite: Map[String, String])
+  extends Transformer with MLWritable {

-  override val uid = Identifiable.randomUID("vectorAttrRewriter")
+  def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
+    this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)

   override def transform(dataset: DataFrame): DataFrame = {
     val metadata = {
```
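Moving `uid` into the primary constructor is what makes these helpers loadable: the reader must rebuild the transformer with the uid recorded in the saved metadata rather than a freshly generated one, while the auxiliary constructor preserves the old `randomUID` behavior for in-process construction. The same shape, reduced to a sketch (these classes are private to RFormula.scala, so this only compiles there):

```scala
// Construction during fit(): fresh random uid via the auxiliary constructor.
val fresh = new VectorAttributeRewriter("features", Map("a_" -> "a "))

// Reconstruction during load(): the persisted uid is passed back in.
val reloaded = new VectorAttributeRewriter(fresh.uid, "features", Map("a_" -> "a "))
```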
```diff
@@ -315,4 +426,48 @@ private class VectorAttributeRewriter(
   }

   override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)
+
+  override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
+}
+
+private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
+
+  override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
+
+  override def load(path: String): VectorAttributeRewriter = super.load(path)
+
+  /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */
+  private[VectorAttributeRewriter]
+  class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter {
+
+    private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: vectorCol, prefixesToRewrite
+      val data = Data(instance.vectorCol, instance.prefixesToRewrite)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {
```
> **Member:** newline
```diff
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[VectorAttributeRewriter].getName
+
+    override def load(path: String): VectorAttributeRewriter = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
+      val vectorCol = data.getString(0)
+      val prefixesToRewrite = data.getAs[Map[String, String]](1)
+      val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
+
+      DefaultParamsReader.getAndSetParams(rewriter, metadata)
+      rewriter
+    }
+  }
+}
```
File: mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
```diff
@@ -20,10 +20,11 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext

-class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new RFormula())
   }
```
```diff
@@ -252,4 +253,41 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
       new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
     assert(attrs === expectedAttrs)
   }
+
+  test("read/write: RFormula") {
+    val rFormula = new RFormula()
+      .setFormula("id ~ a:b")
+      .setFeaturesCol("myFeatures")
+      .setLabelCol("myLabels")
+
+    testDefaultReadWrite(rFormula)
+  }
+
+  test("read/write: RFormulaModel") {
+    def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = {
+      assert(model.uid === model2.uid)
+
+      assert(model.resolvedFormula.label === model2.resolvedFormula.label)
+      assert(model.resolvedFormula.terms === model2.resolvedFormula.terms)
+      assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept)
+
+      assert(model.pipelineModel.uid === model2.pipelineModel.uid)
+
+      model.pipelineModel.stages.zip(model2.pipelineModel.stages).foreach {
+        case (transformer1, transformer2) =>
+          assert(transformer1.uid === transformer2.uid)
+          assert(transformer1.params === transformer2.params)
+      }
+    }
+
+    val dataset = sqlContext.createDataFrame(
+      Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+    ).toDF("id", "a", "b")
+
+    val rFormula = new RFormula().setFormula("id ~ a:b")
+
+    val model = rFormula.fit(dataset)
+    val newModel = testDefaultReadWrite(model)
```
> **Member:** Does it not work to use `testDefaultEstimatorReadWrite` instead?
>
> **Contributor (Author):** RFormulaModel has no …
```diff
+
+    checkModelData(model, newModel)
+  }
+}
```
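`testDefaultReadWrite` (from the `DefaultReadWriteTest` trait mixed in above) saves the instance to a temporary directory, loads it back, checks uid and Param values, and returns the loaded copy; fields that live outside the Param machinery, such as `resolvedFormula` and `pipelineModel`, are why the explicit `checkModelData` comparison is still needed. A sketch of the general shape, assuming the helper returns the reloaded instance:

```scala
// Round-trip plus a model-specific check that Params alone would miss.
val newModel: RFormulaModel = testDefaultReadWrite(model)
assert(newModel.resolvedFormula.terms === model.resolvedFormula.terms)
```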
> I prefer to keep the scope `private[ml]`, rather than restricting it to `private[feature]`, because CrossValidator accesses the pipelineModel.