Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ class BinaryLogisticRegressionSummary private[classification] (
// TODO: Allow the user to vary the number of bins using a setBins method in
// BinaryClassificationMetrics. For now the default is set to 100.
@transient private val binaryMetrics = new BinaryClassificationMetrics(
predictions.select(probabilityCol, labelCol).rdd.map {
predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map {
case Row(score: Vector, label: Double) => (score(1), label)
}, 100
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
} else {
link.unlink(0.0)
}
predictions.select(col(model.getLabelCol), w).rdd.map {
predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map {
case Row(y: Double, weight: Double) =>
family.deviance(y, wtdmu, weight)
}.sum()
Expand All @@ -1004,7 +1004,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0")
lazy val deviance: Double = {
val w = weightCol
predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
family.deviance(label, pred, weight)
}.sum()
Expand All @@ -1030,9 +1030,10 @@ class GeneralizedLinearRegressionSummary private[regression] (
lazy val aic: Double = {
val w = weightCol
val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
(label, pred, weight)
val t = predictions.select(
col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
(label, pred, weight)
}
family.aic(t, deviance, numInstances, weightSum) + 2 * rank
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.LongType

class LogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
Expand Down Expand Up @@ -1776,6 +1777,21 @@ class LogisticRegressionSuite
summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
}

test("evaluate with labels that are not doubles") {
// Evaluate a test set with Label that is a numeric type other than Double
val lr = new LogisticRegression()
.setMaxIter(1)
.setRegParam(1.0)
val model = lr.fit(smallBinaryDataset)
val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]

val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType),
col(model.getFeaturesCol))
val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]

assert(summary.areaUnderROC === longSummary.areaUnderROC)
}

test("statistics on training data") {
// Test that loss is monotonically decreasing.
val lr = new LogisticRegression()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.FloatType

class GeneralizedLinearRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
Expand Down Expand Up @@ -1067,6 +1068,30 @@ class GeneralizedLinearRegressionSuite
idx += 1
}
}

test("evaluate with labels that are not doubles") {
// Evaulate with a dataset that contains Labels not as doubles to verify correct casting
val dataset = Seq(
Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)),
Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)),
Instance(29.0, 1.0, Vectors.dense(3.0, 13.0))
).toDF()

val trainer = new GeneralizedLinearRegression()
.setMaxIter(1)
val model = trainer.fit(dataset)
assert(model.hasSummary)
val summary = model.summary

val longLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType),
col(model.getFeaturesCol))
val evalSummary = model.evaluate(longLabelDataset)
// The calculations below involve pattern matching with Label as a double
assert(evalSummary.nullDeviance === summary.nullDeviance)
assert(evalSummary.deviance === summary.deviance)
assert(evalSummary.aic === summary.aic)
}
}

object GeneralizedLinearRegressionSuite {
Expand Down