From 4ec606df4e4bb74c9edf79a7629d600cbdbaed91 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 3 Aug 2016 14:54:56 +0800 Subject: [PATCH] deal with zero thresholds --- .../ProbabilisticClassifier.scala | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 19df8f7edd43c..53a606d009f37 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -201,11 +201,18 @@ abstract class ProbabilisticClassificationModel[ probability.argmax } else { val thresholds: Array[Double] = getThresholds - val scaledProbability: Array[Double] = - probability.toArray.zip(thresholds).map { case (p, t) => - if (t == 0.0) Double.PositiveInfinity else p / t - } - Vectors.dense(scaledProbability).argmax + + if (thresholds.contains(0.0)) { + val indices = thresholds.zipWithIndex.filter(_._1 == 0.0).map(_._2) + val values = indices.map(probability.apply) + Vectors.sparse(numClasses, indices, values).argmax + } else { + val scaledProbability: Array[Double] = + probability.toArray.zip(thresholds).map { case (p, t) => + if (t == 0.0) Double.PositiveInfinity else p / t + } + Vectors.dense(scaledProbability).argmax + } } } }