From b1a6310d3aea1b728a812a821e9b6d8075a1bdf6 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 29 Nov 2016 18:13:16 +0800 Subject: [PATCH 1/5] create pr --- .../apache/spark/ml/classification/OneVsRest.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index f4ab0a074c420..9f7702d7e8717 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -140,6 +140,20 @@ final class OneVsRestModel private[ml] ( this(uid, Metadata.empty, models.asScala.toArray) } + /** @group setParam */ + @Since("2.2.0") + def setFeaturesCol(value: String): this.type = { + models.foreach(_.setFeaturesCol(value)) + set(featuresCol, value) + } + + /** @group setParam */ + @Since("2.2.0") + def setPredictionCol(value: String): this.type = { + models.foreach(_.setPredictionCol(value)) + set(predictionCol, value) + } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) From f4b537434800118b751d32538bce7c836c26c759 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 29 Nov 2016 18:21:09 +0800 Subject: [PATCH 2/5] update --- .../apache/spark/ml/classification/OneVsRest.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 9f7702d7e8717..c17520fd794a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -142,17 +142,11 @@ final class OneVsRestModel private[ml] ( /** @group setParam */ @Since("2.2.0") - def setFeaturesCol(value: String): this.type = { - models.foreach(_.setFeaturesCol(value)) - set(featuresCol, value) - } + def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ @Since("2.2.0") - def setPredictionCol(value: String): this.type = { - models.foreach(_.setPredictionCol(value)) - set(predictionCol, value) - } + def setPredictionCol(value: String): this.type = set(predictionCol, value) @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { @@ -189,6 +183,8 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } + model.setFeaturesCol($(featuresCol)) + model.setPredictionCol($(predictionCol)) val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) From 9de08f21f50f5eca2fff80ec7ab7463397f80013 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 29 Nov 2016 19:07:30 +0800 Subject: [PATCH 3/5] add test --- .../spark/ml/classification/OneVsRestSuite.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 3f9bcec427399..d9030de7e22a9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.StringIndexer -import org.apache.spark.ml.linalg.{DenseMatrix, Vectors} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS @@ -33,6 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -136,6 +137,19 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(outputFields.contains("p")) } + test("SPARK-18625 : OneVsRestModel should support setFeaturesCol and setPredictionCol") { + val ova = new OneVsRest().setClassifier(new LogisticRegression) + val ovaModel = ova.fit(dataset) + val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea")) + ovaModel.setFeaturesCol("fea") + ovaModel.setPredictionCol("pred") + val transformedDataset = ovaModel.transform(dataset2) + val outputFields = transformedDataset.schema.fieldNames.toSet + assert(outputFields.contains("y")) + assert(outputFields.contains("fea")) + assert(outputFields.contains("pred")) + } + test("SPARK-8049: OneVsRest shouldn't output temp columns") { val logReg = new LogisticRegression() .setMaxIter(1) From 61dcd51e601a4db69cbacf15334fd41a2a10998d Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 5 Dec 2016 11:04:08 +0800 Subject: [PATCH 4/5] update --- .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 1 - .../org/apache/spark/ml/classification/OneVsRestSuite.scala | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index c17520fd794a4..674ca77c29393 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -184,7 +184,6 @@ final class OneVsRestModel private[ml] ( predictions + ((index, prediction(1))) } model.setFeaturesCol($(featuresCol)) - model.setPredictionCol($(predictionCol)) val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index d9030de7e22a9..aacb7921b835f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -145,9 +145,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau ovaModel.setPredictionCol("pred") val transformedDataset = ovaModel.transform(dataset2) val outputFields = transformedDataset.schema.fieldNames.toSet - assert(outputFields.contains("y")) - assert(outputFields.contains("fea")) - assert(outputFields.contains("pred")) + assert(outputFields === Set("y", "fea", "pred")) } test("SPARK-8049: OneVsRest shouldn't output temp columns") { From b45f558e4aa5799f2e00d87ab0490ec6755ecc73 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 5 Dec 2016 11:04:44 +0800 Subject: [PATCH 5/5] update version --- .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 674ca77c29393..e58b30d66588c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -141,11 +141,11 @@ final class OneVsRestModel private[ml] ( } /** @group setParam */ - @Since("2.2.0") + @Since("2.1.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ - @Since("2.2.0") + @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) @Since("1.4.0")