From 91f5be81a24eb14ae6206fa5f01cb2cf286504b6 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 28 Sep 2016 15:15:08 -0700 Subject: [PATCH 1/3] added fix in LogisticRegression to evaluate with numeric types other than double --- .../classification/LogisticRegression.scala | 2 +- .../LogisticRegressionSuite.scala | 20 ++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 5ab63d1de95d3..329961a25d984 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1169,7 +1169,7 @@ class BinaryLogisticRegressionSummary private[classification] ( // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. @transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(probabilityCol, labelCol).rdd.map { + predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map { case Row(score: Vector, label: Double) => (score(1), label) }, 100 ) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8451e60144981..0673a21963233 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -32,7 +32,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.LongType class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -1776,6 +1777,23 @@ class LogisticRegressionSuite summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) } + test("evaluate with labels that are not doubles") { + // Evaluate a test set with Label that is a numeric type other than Double + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(smallBinaryDataset) + + val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary] + + val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType), + col(model.getFeaturesCol)) + val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary] + + assert(summary.areaUnderROC === longSummary.areaUnderROC) + } + test("statistics on training data") { // Test that loss is monotonically decreasing. val lr = new LogisticRegression() From 5c508611e0148b8224be6b06158d1ed4f349a70d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 28 Sep 2016 15:45:59 -0700 Subject: [PATCH 2/3] fixed same issue in GeneralizedLinearRegression --- .../GeneralizedLinearRegression.scala | 11 ++++---- .../GeneralizedLinearRegressionSuite.scala | 27 +++++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 02b27fb650979..bb9e150c49772 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -992,7 +992,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( } else { link.unlink(0.0) } - predictions.select(col(model.getLabelCol), w).rdd.map { + predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map { case Row(y: Double, weight: Double) => family.deviance(y, wtdmu, weight) }.sum() @@ -1004,7 +1004,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.0.0") lazy val deviance: Double = { val w = weightCol - predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { + predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { case Row(label: Double, pred: Double, weight: Double) => family.deviance(label, pred, weight) }.sum() @@ -1030,9 +1030,10 @@ class GeneralizedLinearRegressionSummary private[regression] ( lazy val aic: Double = { val w = weightCol val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0) - val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { - case Row(label: Double, pred: Double, weight: Double) => - (label, pred, weight) + val t = predictions.select( + col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + case Row(label: Double, pred: Double, weight: Double) => + (label, pred, weight) } family.aic(t, deviance, numInstances, weightSum) + 2 * rank } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 937aa7d3c2045..5b02339e7ce38 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.FloatType class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -1067,6 +1068,32 @@ class GeneralizedLinearRegressionSuite idx += 1 } } + + test("evaluate with labels that are not doubles") { + // Evaulate with a dataset that contains Labels not as doubles to verify correct casting + val datasetWithWeight = Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + ).toDF() + + val trainer = new GeneralizedLinearRegression() + .setWeightCol("weight") + val model = trainer.fit(datasetWithWeight) + + assert(model.hasSummary) + val summary = model.summary + assert(summary.isInstanceOf[GeneralizedLinearRegressionTrainingSummary]) + + val longLabelDataset = datasetWithWeight.select(col(model.getLabelCol).cast(FloatType), + col(model.getFeaturesCol), col(model.getWeightCol)) + val evalSummary = model.evaluate(longLabelDataset) + // The calculations below involve pattern matching with Label as a double + assert(evalSummary.nullDeviance === summary.nullDeviance) + assert(evalSummary.deviance === summary.deviance) + assert(evalSummary.aic === summary.aic) + } } object GeneralizedLinearRegressionSuite { From 640f84b6ac7064d5848107910eccf7a25587fafb Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 29 Sep 2016 11:03:41 -0700 Subject: [PATCH 3/3] simplified unit tests --- .../LogisticRegressionSuite.scala | 6 ++---- .../GeneralizedLinearRegressionSuite.scala | 18 ++++++++---------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 0673a21963233..42b56754e0835 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1780,11 +1780,9 @@ class LogisticRegressionSuite test("evaluate with labels that are not doubles") { // Evaluate a test set with Label that is a numeric type other than Double val lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(1.0) - .setThreshold(0.6) + .setMaxIter(1) + .setRegParam(1.0) val model = lr.fit(smallBinaryDataset) - val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary] val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType), diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 5b02339e7ce38..ac1ef5feb95ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1071,23 +1071,21 @@ class GeneralizedLinearRegressionSuite test("evaluate with labels that are not doubles") { // Evaulate with a dataset that contains Labels not as doubles to verify correct casting - val datasetWithWeight = Seq( + val dataset = Seq( Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 1.0, Vectors.dense(3.0, 13.0)) ).toDF() val trainer = new GeneralizedLinearRegression() - .setWeightCol("weight") - val model = trainer.fit(datasetWithWeight) - + .setMaxIter(1) + val model = trainer.fit(dataset) assert(model.hasSummary) val summary = model.summary - assert(summary.isInstanceOf[GeneralizedLinearRegressionTrainingSummary]) - val longLabelDataset = datasetWithWeight.select(col(model.getLabelCol).cast(FloatType), - col(model.getFeaturesCol), col(model.getWeightCol)) + val longLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType), + col(model.getFeaturesCol)) val evalSummary = model.evaluate(longLabelDataset) // The calculations below involve pattern matching with Label as a double assert(evalSummary.nullDeviance === summary.nullDeviance)