From 7f53d08b2cfd17353f1662417e3a6999cc3e3408 Mon Sep 17 00:00:00 2001 From: acidghost Date: Thu, 11 Jun 2015 16:09:52 +0200 Subject: [PATCH 1/9] get prediction probabilities of naive bayes classification --- .../mllib/classification/NaiveBayes.scala | 64 +++++++++++++++---- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index f51ee36d0dfcb..c75acfb6d60cd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.classification import java.lang.{Iterable => JIterable} +import breeze.linalg.{max, min} + import scala.collection.JavaConverters._ import org.json4s.JsonDSL._ @@ -93,19 +95,10 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { modelType match { case Multinomial => - val prob = thetaMatrix.multiply(testData) - BLAS.axpy(1.0, piVector, prob) + val prob = multinomialCalculation(testData) labels(prob.argmax) case Bernoulli => - testData.foreachActive { (index, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.") - } - } - val prob = thetaMinusNegTheta.get.multiply(testData) - BLAS.axpy(1.0, piVector, prob) - BLAS.axpy(1.0, negThetaSum.get, prob) + val prob = bernoulliCalculation(testData) labels(prob.argmax) case _ => // This should never happen. 
@@ -113,6 +106,55 @@ class NaiveBayesModel private[mllib] ( } } + def predictProbabilities(testData: RDD[Vector]): RDD[Map[Double, Double]] = { + val bcModel = testData.context.broadcast(this) + testData.mapPartitions { iter => + val model = bcModel.value + iter.map(model.predictProbabilities) + } + } + + def predictProbabilities(testData: Vector): Map[Double, Double] = { + modelType match { + case Multinomial => + val prob = multinomialCalculation(testData) + posteriorProbabilities(prob) + case Bernoulli => + val prob = bernoulliCalculation(testData) + posteriorProbabilities(prob) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: $modelType.") + } + } + + protected[classification] def multinomialCalculation(testData: Vector): DenseVector = { + val prob = thetaMatrix.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + prob + } + + protected[classification] def bernoulliCalculation(testData: Vector): DenseVector = { + testData.foreachActive { (index, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.") + } + } + val prob = thetaMinusNegTheta.get.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob + } + + protected[classification] def posteriorProbabilities(prob: DenseVector): Map[Double, Double] = { + val maxLogs = max(prob.toBreeze) + val minLogs = min(prob.toBreeze) + val normalized = prob.toArray.map(e => (e - minLogs) / (maxLogs - minLogs)) + val total = normalized.sum + (for ((v, i) <- normalized.map(_ / total).zipWithIndex) yield (labels(i), v)).toMap + } + override def save(sc: SparkContext, path: String): Unit = { val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) From 83673fc43c3ec6c09e0e888b81f1c67b30f370f9 Mon Sep 17 00:00:00 2001 From: Andrea Jemmett Date: Mon, 15 Jun 2015 20:01:19 +0200 
Subject: [PATCH 2/9] private NaiveBayesModel calculations & pattern matching fallback --- .../spark/mllib/classification/NaiveBayes.scala | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index c75acfb6d60cd..5f565607ef5d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -100,9 +100,6 @@ class NaiveBayesModel private[mllib] ( case Bernoulli => val prob = bernoulliCalculation(testData) labels(prob.argmax) - case _ => - // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") } } @@ -122,19 +119,16 @@ class NaiveBayesModel private[mllib] ( case Bernoulli => val prob = bernoulliCalculation(testData) posteriorProbabilities(prob) - case _ => - // This should never happen. 
- throw new UnknownError(s"Invalid modelType: $modelType.") } } - protected[classification] def multinomialCalculation(testData: Vector): DenseVector = { + private def multinomialCalculation(testData: Vector): DenseVector = { val prob = thetaMatrix.multiply(testData) BLAS.axpy(1.0, piVector, prob) prob } - protected[classification] def bernoulliCalculation(testData: Vector): DenseVector = { + private def bernoulliCalculation(testData: Vector): DenseVector = { testData.foreachActive { (index, value) => if (value != 0.0 && value != 1.0) { throw new SparkException( @@ -147,7 +141,7 @@ class NaiveBayesModel private[mllib] ( prob } - protected[classification] def posteriorProbabilities(prob: DenseVector): Map[Double, Double] = { + private def posteriorProbabilities(prob: DenseVector): Map[Double, Double] = { val maxLogs = max(prob.toBreeze) val minLogs = min(prob.toBreeze) val normalized = prob.toArray.map(e => (e - minLogs) / (maxLogs - minLogs)) From 35259a5c97ea39aafdf34e63c61ed984d5cf1795 Mon Sep 17 00:00:00 2001 From: acidghost Date: Tue, 16 Jun 2015 11:01:22 +0200 Subject: [PATCH 3/9] prediction probabilities calculation --- .../apache/spark/mllib/classification/NaiveBayes.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 5f565607ef5d5..8a044130aa774 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -142,11 +142,9 @@ class NaiveBayesModel private[mllib] ( } private def posteriorProbabilities(prob: DenseVector): Map[Double, Double] = { - val maxLogs = max(prob.toBreeze) - val minLogs = min(prob.toBreeze) - val normalized = prob.toArray.map(e => (e - minLogs) / (maxLogs - minLogs)) - val total = normalized.sum - (for ((v, i) <- normalized.map(_ / 
total).zipWithIndex) yield (labels(i), v)).toMap + val probabilities = prob.toArray.map(p => math.exp(p / 1000)) + val probSum = probabilities.sum + labels.zip(probabilities.map(_ / probSum)).toMap } override def save(sc: SparkContext, path: String): Unit = { From 63c8d20fc3d0bc1fcfb602069f57187afdad8bf6 Mon Sep 17 00:00:00 2001 From: acidghost Date: Tue, 16 Jun 2015 15:20:15 +0200 Subject: [PATCH 4/9] better log probs scaling --- .../org/apache/spark/mllib/classification/NaiveBayes.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 8a044130aa774..128928f01e3eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -142,7 +142,9 @@ class NaiveBayesModel private[mllib] ( } private def posteriorProbabilities(prob: DenseVector): Map[Double, Double] = { - val probabilities = prob.toArray.map(p => math.exp(p / 1000)) + val probArray = prob.toArray + val maxLog = probArray.max + val probabilities = probArray.map(lp => math.exp(lp / math.abs(maxLog))) val probSum = probabilities.sum labels.zip(probabilities.map(_ / probSum)).toMap } From 41b6cb8203a2a33b2e296ff28177d0594073296e Mon Sep 17 00:00:00 2001 From: acidghost Date: Tue, 16 Jun 2015 16:08:10 +0200 Subject: [PATCH 5/9] prediction probs calculation --- .../org/apache/spark/mllib/classification/NaiveBayes.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 128928f01e3eb..627c714760649 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -144,7 +144,7 @@ class NaiveBayesModel private[mllib] ( private def posteriorProbabilities(prob: DenseVector): Map[Double, Double] = { val probArray = prob.toArray val maxLog = probArray.max - val probabilities = probArray.map(lp => math.exp(lp / math.abs(maxLog))) + val probabilities = probArray.map(lp => math.exp(lp - maxLog)) val probSum = probabilities.sum labels.zip(probabilities.map(_ / probSum)).toMap } From 29a114758961d3f42cdc95a40c9f5105e62cf7a4 Mon Sep 17 00:00:00 2001 From: acidghost Date: Fri, 19 Jun 2015 10:13:52 +0200 Subject: [PATCH 6/9] prediction probabilities as Vector --- .../apache/spark/mllib/classification/NaiveBayes.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 627c714760649..ce3cf9f0faa3c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -19,8 +19,6 @@ package org.apache.spark.mllib.classification import java.lang.{Iterable => JIterable} -import breeze.linalg.{max, min} - import scala.collection.JavaConverters._ import org.json4s.JsonDSL._ @@ -103,7 +101,7 @@ class NaiveBayesModel private[mllib] ( } } - def predictProbabilities(testData: RDD[Vector]): RDD[Map[Double, Double]] = { + def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = { val bcModel = testData.context.broadcast(this) testData.mapPartitions { iter => val model = bcModel.value @@ -111,7 +109,7 @@ class NaiveBayesModel private[mllib] ( } } - def predictProbabilities(testData: Vector): Map[Double, Double] = { + def predictProbabilities(testData: Vector): Vector = { modelType match { case Multinomial => val prob = multinomialCalculation(testData) @@ -141,12 +139,12 @@ class 
NaiveBayesModel private[mllib] ( prob } - private def posteriorProbabilities(prob: DenseVector): Map[Double, Double] = { + private def posteriorProbabilities(prob: DenseVector): Vector = { val probArray = prob.toArray val maxLog = probArray.max val probabilities = probArray.map(lp => math.exp(lp - maxLog)) val probSum = probabilities.sum - labels.zip(probabilities.map(_ / probSum)).toMap + new DenseVector(labels.zip(probabilities.map(_ / probSum)).sortBy(_._1).map(_._2)) } override def save(sc: SparkContext, path: String): Unit = { From 843da0c4b1a6abfe2c7cfb5188f712adb37bd069 Mon Sep 17 00:00:00 2001 From: acidghost Date: Fri, 19 Jun 2015 10:43:51 +0200 Subject: [PATCH 7/9] testing prediction probabilities are "sane" --- .../classification/NaiveBayesSuite.scala | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index f7fc8730606af..fe50ee0a263a0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -116,6 +116,22 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { } } + def validatePredictionsProbabilities(predictionsProbabilities: Seq[Array[Double]], input: Seq[LabeledPoint]) = { + predictionsProbabilities.foreach { probabilities => + val sum = probabilities.sum + // Check that prediction probabilities sum up to one + // with an epsilon of 10^-2 + assert(sum > 0.99 && sum < 1.01) + } + + val wrongPredictions = predictionsProbabilities.zip(input).count { + case (prediction, expected) => + prediction.indexOf(prediction.max).toDouble != expected.label + } + // At least 80% of the predictions should be on. 
+ assert(wrongPredictions < input.length / 5) + } + test("model types") { assert(Multinomial === "multinomial") assert(Bernoulli === "bernoulli") @@ -154,6 +170,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + + // Test prediction probabilities on RDD. + validatePredictionsProbabilities(model.predictProbabilities(validationRDD.map(_.features)).map(_.toArray).collect(), validationData) + + // Test prediction probabilities on Array. + validatePredictionsProbabilities(validationData.map(row => model.predictProbabilities(row.features)).map(_.toArray), validationData) } test("Naive Bayes Bernoulli") { @@ -182,6 +204,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + + // Test prediction probabilities on RDD. + validatePredictionsProbabilities(model.predictProbabilities(validationRDD.map(_.features)).map(_.toArray).collect(), validationData) + + // Test prediction probabilities on Array. 
+ validatePredictionsProbabilities(validationData.map(row => model.predictProbabilities(row.features)).map(_.toArray), validationData) } test("detect negative values") { From 49f0d5944fc0030dd445f6a8eb55975bc3c1b9a2 Mon Sep 17 00:00:00 2001 From: Andrea Jemmett Date: Sun, 21 Jun 2015 14:05:50 +0200 Subject: [PATCH 8/9] initial tests for posteriors correctness --- .../classification/NaiveBayesSuite.scala | 78 ++++++++++++++++++- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index fe50ee0a263a0..9dce43af34e95 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -116,7 +116,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { } } - def validatePredictionsProbabilities(predictionsProbabilities: Seq[Array[Double]], input: Seq[LabeledPoint]) = { + def validatePredictionsProbabilities(predictionsProbabilities: Seq[Array[Double]], + input: Seq[LabeledPoint], + modelType: String = Multinomial) = { predictionsProbabilities.foreach { probabilities => val sum = probabilities.sum // Check that prediction probabilities sum up to one @@ -130,6 +132,74 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { } // At least 80% of the predictions should be on. assert(wrongPredictions < input.length / 5) + + comparePosteriorsWithR(predictionsProbabilities.take(10), modelType) + } + + /** + * The following instructions reproduce the model using R's e1071 package. + * + * First, use the following Scala code to save the data into `path`.
+ * + * testRDD.map { x => + * s"${x.label}, ${x.features.toArray.mkString(", ")}" + * }.saveAsTextFile("path") + * + * Use the following R code to load the data and train the model with the e1071 package. + * + * library(e1071) + * data <- read.csv("path", header = FALSE) + * labels <- factor(data$V1) + * features <- data.frame(data$V2, data$V3, data$V4, data$V5) + * model <- naiveBayes(features, labels) + * predictions <- predict(model, features[1:10, -1], type = "raw") + * + */ + def comparePosteriorsWithR(predictionsProbabilities: Seq[Array[Double]], + modelType: String = Multinomial, + epsilon: Double = 0.1) = { + require(predictionsProbabilities.length == 10) + + val posteriorsFromR = modelType match { + case Multinomial => + Array( + Array(2.942994e-07, 5.467545e-11, 9.999997e-01), + Array(2.931850e-07, 4.922381e-12, 9.999997e-01), + Array(9.997708e-01, 3.424392e-06, 2.257879e-04), + Array(9.991757e-01, 6.132008e-04, 2.110877e-04), + Array(9.281650e-14, 8.199463e-17, 1.000000e+00), + Array(8.099445e-01, 3.821142e-05, 1.900173e-01), + Array(2.667884e-01, 7.331288e-01, 8.276015e-05), + Array(9.999776e-01, 2.163690e-06, 2.023486e-05), + Array(9.997814e-01, 2.441990e-06, 2.161960e-04), + Array(8.850206e-14, 6.692205e-18, 1.000000e+00) + ) + case Bernoulli => + Array( + Array(1.048099e-09, 1.000000e+00, 1.578642e-09), + Array(1.831993e-19, 9.999999e-01, 1.190036e-07), + Array(4.664977e-12, 1.000000e+00, 1.291666e-12), + Array(3.224249e-11, 2.433594e-02, 9.756641e-01), + Array(9.610916e-01, 1.256859e-13, 3.890841e-02), + Array(8.318820e-01, 1.097496e-01, 5.836849e-02), + Array(8.318820e-01, 1.097496e-01, 5.836849e-02), + Array(9.610916e-01, 1.256859e-13, 3.890841e-02), + Array(8.318820e-01, 1.097496e-01, 5.836849e-02), + Array(8.318820e-01, 1.097496e-01, 5.836849e-02) + ) + } + + predictionsProbabilities.zip(posteriorsFromR).foreach { + case (probs, fromR) => + val p = probs.indexOf(probs.max) + val r = fromR.indexOf(fromR.max) + // Checking that the prediction is
the same + if (p == r) { + probs.zip(fromR).foreach { + case (prob, probFromR) => assert(prob > (probFromR - epsilon) && prob < (probFromR + epsilon)) + } + } + } } test("model types") { @@ -206,10 +276,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { validatePrediction(validationData.map(row => model.predict(row.features)), validationData) // Test prediction probabilities on RDD. - validatePredictionsProbabilities(model.predictProbabilities(validationRDD.map(_.features)).map(_.toArray).collect(), validationData) + validatePredictionsProbabilities(model.predictProbabilities(validationRDD.map(_.features)).map(_.toArray).collect(), + validationData, Bernoulli) // Test prediction probabilities on Array. - validatePredictionsProbabilities(validationData.map(row => model.predictProbabilities(row.features)).map(_.toArray), validationData) + validatePredictionsProbabilities(validationData.map(row => model.predictProbabilities(row.features)).map(_.toArray), + validationData, Bernoulli) } test("detect negative values") { From 2ce7853ee1fa0ea3561d09dd9f37a47dd53f7d32 Mon Sep 17 00:00:00 2001 From: acidghost Date: Mon, 22 Jun 2015 12:51:50 +0200 Subject: [PATCH 9/9] better epsilon for sum to one --- .../apache/spark/mllib/classification/NaiveBayesSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 9dce43af34e95..c8c820d50b154 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, 
MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils object NaiveBayesSuite { @@ -122,8 +123,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { predictionsProbabilities.foreach { probabilities => val sum = probabilities.sum // Check that prediction probabilities sum up to one - // with an epsilon of 10^-2 - assert(sum > 0.99 && sum < 1.01) + // with an epsilon of 10^-5 + assert(sum ~== 1.0 relTol 0.00001) } val wrongPredictions = predictionsProbabilities.zip(input).count {