Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 167 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ package org.apache.spark.ml.feature
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.annotation.Experimental
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -68,7 +70,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
* will be created from the specified response variable in the formula.
*/
@Experimental
class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
class RFormula(override val uid: String)
extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("rFormula"))

Expand Down Expand Up @@ -180,6 +183,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
}

/**
 * Companion object for [[RFormula]] providing read support, so a saved
 * RFormula estimator can be reloaded via `RFormula.load(path)`.
 */
@Since("2.0.0")
object RFormula extends DefaultParamsReadable[RFormula] {

  // Overridden only to narrow the return type from the generic reader to RFormula.
  @Since("2.0.0")
  override def load(path: String): RFormula = super.load(path)
}

/**
* :: Experimental ::
* A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
Expand All @@ -189,9 +199,9 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
@Experimental
class RFormulaModel private[feature](
override val uid: String,
resolvedFormula: ResolvedRFormula,
pipelineModel: PipelineModel)
extends Model[RFormulaModel] with RFormulaBase {
private[ml] val resolvedFormula: ResolvedRFormula,
private[ml] val pipelineModel: PipelineModel)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to leave the private scope of ml, other than feature, because CrossValidator accesses the pipelineModel.

extends Model[RFormulaModel] with RFormulaBase with MLWritable {

override def transform(dataset: DataFrame): DataFrame = {
checkCanTransform(dataset.schema)
Expand Down Expand Up @@ -246,14 +256,71 @@ class RFormulaModel private[feature](
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
"Label column already exists and is not of type DoubleType.")
}

@Since("2.0.0")
override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
}

@Since("2.0.0")
object RFormulaModel extends MLReadable[RFormulaModel] {

  @Since("2.0.0")
  override def read: MLReader[RFormulaModel] = new RFormulaModelReader

  @Since("2.0.0")
  override def load(path: String): RFormulaModel = super.load(path)

  /** [[MLWriter]] instance for [[RFormulaModel]] */
  private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      // Metadata and Params go first so the directory is self-describing.
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      // The resolved formula is persisted as a single-row Parquet dataset.
      val formulaPath = new Path(path, "data").toString
      val formulaDF = sqlContext.createDataFrame(Seq(instance.resolvedFormula))
      formulaDF.repartition(1).write.parquet(formulaPath)
      // The nested pipeline model persists itself into its own subdirectory.
      instance.pipelineModel.save(new Path(path, "pipelineModel").toString)
    }
  }

  /** [[MLReader]] instance for [[RFormulaModel]] */
  private class RFormulaModelReader extends MLReader[RFormulaModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[RFormulaModel].getName

    override def load(path: String): RFormulaModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      // Rebuild the resolved formula from the single persisted row.
      val row = sqlContext.read.parquet(new Path(path, "data").toString)
        .select("label", "terms", "hasIntercept")
        .head()
      val formula = ResolvedRFormula(
        row.getString(0),
        row.getAs[Seq[Seq[String]]](1),
        row.getBoolean(2))

      // Restore the fitted internal pipeline from its subdirectory.
      val pipeline = PipelineModel.load(new Path(path, "pipelineModel").toString)

      val model = new RFormulaModel(metadata.uid, formula, pipeline)
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }
}

/**
* Utility transformer for removing temporary columns from a DataFrame.
* TODO(ekl) make this a public transformer
*/
private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
override val uid = Identifiable.randomUID("columnPruner")
private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
extends Transformer with MLWritable {

def this(columnsToPrune: Set[String]) =
this(Identifiable.randomUID("columnPruner"), columnsToPrune)

override def transform(dataset: DataFrame): DataFrame = {
val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
Expand All @@ -265,6 +332,48 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
}

override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)

override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
}

/**
 * Companion object for [[ColumnPruner]] providing save/load support, so the
 * pruner can round-trip as a stage inside a persisted [[PipelineModel]].
 */
private object ColumnPruner extends MLReadable[ColumnPruner] {

  override def read: MLReader[ColumnPruner] = new ColumnPrunerReader

  override def load(path: String): ColumnPruner = super.load(path)

  /** [[MLWriter]] instance for [[ColumnPruner]] */
  private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter {

    // Schema of the persisted data row: the column names to drop.
    // Stored as a Seq because Set is not a supported SQL type.
    private case class Data(columnsToPrune: Seq[String])

    override protected def saveImpl(path: String): Unit = {
      // Save metadata and Params
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      // Save transformer data: columnsToPrune
      val data = Data(instance.columnsToPrune.toSeq)
      val dataPath = new Path(path, "data").toString
      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  /** [[MLReader]] instance for [[ColumnPruner]] */
  private class ColumnPrunerReader extends MLReader[ColumnPruner] {

    /** Checked against metadata when loading model */
    private val className = classOf[ColumnPruner].getName

    override def load(path: String): ColumnPruner = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString
      val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
      val columnsToPrune = data.getAs[Seq[String]](0).toSet
      val pruner = new ColumnPruner(metadata.uid, columnsToPrune)

      DefaultParamsReader.getAndSetParams(pruner, metadata)
      pruner
    }
  }
}

/**
Expand All @@ -278,11 +387,13 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
* by the value in the map.
*/
private class VectorAttributeRewriter(
vectorCol: String,
prefixesToRewrite: Map[String, String])
extends Transformer {
override val uid: String,
val vectorCol: String,
val prefixesToRewrite: Map[String, String])
extends Transformer with MLWritable {

override val uid = Identifiable.randomUID("vectorAttrRewriter")
def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)

override def transform(dataset: DataFrame): DataFrame = {
val metadata = {
Expand Down Expand Up @@ -315,4 +426,48 @@ private class VectorAttributeRewriter(
}

override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)

override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
}

/**
 * Companion object for [[VectorAttributeRewriter]] providing save/load support,
 * so the rewriter can round-trip as a stage inside a persisted [[PipelineModel]].
 */
private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {

  override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader

  override def load(path: String): VectorAttributeRewriter = super.load(path)

  /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */
  private[VectorAttributeRewriter]
  class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter {

    // Schema of the persisted data row: the target vector column and the
    // attribute-name prefix replacements to apply.
    private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])

    override protected def saveImpl(path: String): Unit = {
      // Save metadata and Params
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      // Save transformer data: vectorCol, prefixesToRewrite
      val data = Data(instance.vectorCol, instance.prefixesToRewrite)
      val dataPath = new Path(path, "data").toString
      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  /** [[MLReader]] instance for [[VectorAttributeRewriter]] */
  private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {

    /** Checked against metadata when loading model */
    private val className = classOf[VectorAttributeRewriter].getName

    override def load(path: String): VectorAttributeRewriter = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString
      val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
      val vectorCol = data.getString(0)
      val prefixesToRewrite = data.getAs[Map[String, String]](1)
      val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)

      DefaultParamsReader.getAndSetParams(rewriter, metadata)
      rewriter
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
// TODO: SPARK-11892: This case may require special handling.
throw new UnsupportedOperationException("CrossValidator write will fail because it" +
" cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
case rform: RFormulaModel =>
// TODO: SPARK-11891: This case may require special handling.
throw new UnsupportedOperationException("CrossValidator write will fail because it" +
" cannot yet handle an estimator containing an RFormulaModel")
case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
case _: Params => Array()
}
val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext

class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new RFormula())
}
Expand Down Expand Up @@ -252,4 +253,41 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
assert(attrs === expectedAttrs)
}

  test("read/write: RFormula") {
    // Round-trips the unfitted estimator's Params (formula, featuresCol,
    // labelCol) through save/load via the default Params persistence helper.
    val rFormula = new RFormula()
      .setFormula("id ~ a:b")
      .setFeaturesCol("myFeatures")
      .setLabelCol("myLabels")

    testDefaultReadWrite(rFormula)
  }

test("read/write: RFormulaModel") {
def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = {
assert(model.uid === model2.uid)

assert(model.resolvedFormula.label === model2.resolvedFormula.label)
assert(model.resolvedFormula.terms === model2.resolvedFormula.terms)
assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept)

assert(model.pipelineModel.uid === model2.pipelineModel.uid)

model.pipelineModel.stages.zip(model2.pipelineModel.stages).foreach {
case (transformer1, transformer2) =>
assert(transformer1.uid === transformer2.uid)
assert(transformer1.params === transformer2.params)
}
}

val dataset = sqlContext.createDataFrame(
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b")

val rFormula = new RFormula().setFormula("id ~ a:b")

val model = rFormula.fit(dataset)
val newModel = testDefaultReadWrite(model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it not work to use testDefaultEstimatorReadWrite instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RFormulaModel has no formula, while RFormula has formula. Error occurs if we use testEstimatorAndModeReadWrite.

checkModelData(model, newModel)
}
}