diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
new file mode 100644
index 0000000000000..37a078c5f95ea
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.mllib.evaluation.RankingMetrics
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions._
+
+/**
+ * :: Experimental ::
+ * Evaluator for ranking, which expects two input columns: prediction and label.
+ * Both the prediction and label columns must be arrays of the same element type T,
+ * for which a ClassTag must be available.
+ */
+@Since("2.0.0")
+@Experimental
+final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") override val uid: String)
+  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable with Logging {
+
+  @Since("2.0.0")
+  def this() = this(Identifiable.randomUID("rankingEval"))
+
+  @Since("2.0.0")
+  final val k = new IntParam(this, "k", "Top-K cutoff", (x: Int) => x > 0)
+
+  /** @group getParam */
+  @Since("2.0.0")
+  def getK: Int = $(k)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setK(value: Int): this.type = set(k, value)
+
+  setDefault(k -> 1)
+
+  /**
+   * Param for metric name in evaluation. Supports:
+   *  - `"map"` (default): Mean Average Precision
+   *  - `"mapk"`: Precision@K, the precision of the top K predictions
+   *  - `"ndcg"`: Normalized Discounted Cumulative Gain
+   *  - `"mrr"`: Mean Reciprocal Rank
+   *
+   * @group param
+   */
+  @Since("2.0.0")
+  val metricName: Param[String] = {
+    val allowedParams = ParamValidators.inArray(Array("map", "mapk", "ndcg", "mrr"))
+    new Param(this, "metricName", "metric name in evaluation (map|mapk|ndcg|mrr)", allowedParams)
+  }
+
+  /** @group getParam */
+  @Since("2.0.0")
+  def getMetricName: String = $(metricName)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setMetricName(value: String): this.type = set(metricName, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  setDefault(metricName -> "map")
+
+  @Since("2.0.0")
+  override def evaluate(dataset: Dataset[_]): Double = {
+    val schema = dataset.schema
+    val predictionColName = $(predictionCol)
+    val predictionType = schema(predictionColName).dataType
+    val labelColName = $(labelCol)
+    val labelType = schema(labelColName).dataType
+    require(predictionType == labelType,
+      s"Prediction column $predictionColName and label column $labelColName " +
+        s"must be of the same type, but prediction column $predictionColName is $predictionType " +
+        s"and label column $labelColName is $labelType")
+
+    val predictionAndLabels = dataset
+      .select(col(predictionColName), col(labelColName))
+      .rdd
+      .map { case Row(prediction: Seq[T], label: Seq[T]) => (prediction.toArray, label.toArray) }
+
+    val metrics = new RankingMetrics[T](predictionAndLabels)
+    $(metricName) match {
+      case "map" => metrics.meanAveragePrecision
+      case "ndcg" => metrics.ndcgAt($(k))
+      case "mapk" => metrics.precisionAt($(k)) // precision of the top K predictions
+      case "mrr" => metrics.meanReciprocalRank
+    }
+  }
+
+  // MAP, Precision@K, NDCG and MRR all improve as they increase, so larger values
+  // are better for every supported metric.
+  @Since("2.0.0")
+  override def isLargerBetter: Boolean = true
+
+  @Since("2.0.0")
+  override def copy(extra: ParamMap): RankingEvaluator[T] = defaultCopy(extra)
+}
+
+@Since("2.0.0")
+object RankingEvaluator extends DefaultParamsReadable[RankingEvaluator[_]] {
+
+  @Since("2.0.0")
+  override def load(path: String): RankingEvaluator[_] = super.load(path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index b98aa0534152b..22fdfb6f7bbf9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -155,6 +155,45 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
     }.mean()
   }
 
+  /**
+   * Compute the mean reciprocal rank (MRR) of all the queries.
+   *
+   * The reciprocal rank of a query response is the multiplicative inverse of the rank of
+   * the first relevant document, and MRR is the average of the reciprocal ranks over a
+   * sample of queries. A larger MRR means the first relevant document tends to appear
+   * earlier in the ranking, so the metric is well-suited to applications in which only
+   * the first result matters.
+   *
+   * If a query has an empty ground truth set, zero will be used as its reciprocal rank,
+   * together with a log warning.
+   *
+   * See the following paper for details:
+   *
+   * Brian McFee and Gert R. G. Lanckriet, "Metric Learning to Rank". ICML 2010: 775-782
+   *
+   * @return the mean reciprocal rank of all the queries.
+   */
+  lazy val meanReciprocalRank: Double = {
+    predictionAndLabels.map { case (pred, lab) =>
+      val labSet = lab.toSet
+
+      if (labSet.nonEmpty) {
+        var i = 0
+        var reciprocalRank = 0.0
+        while (i < pred.length && reciprocalRank == 0.0) {
+          if (labSet.contains(pred(i))) {
+            reciprocalRank = 1.0 / (i + 1)
+          }
+          i += 1
+        }
+        reciprocalRank
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }.mean()
+  }
+
 }
 
 object RankingMetrics {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala
new file mode 100644
index 0000000000000..76833e0dd9fa2
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class RankingEvaluatorSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  test("params") {
+    ParamsSuite.checkParams(new RankingEvaluator[Int])
+  }
+
+  test("Ranking Evaluator: default params") {
+    // MLlibTestSparkContext provides a shared SparkSession as `spark`.
+    val sqlContext = spark.sqlContext
+
+    val predictionAndLabels = sqlContext.createDataFrame(sc.parallelize(
+      Seq(
+        (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
+        (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
+        (Array[Int](1, 2, 3, 4, 5), Array[Int]())
+      ), 2)).toDF("prediction", "label")
+
+    // default = map, k = 1
+    val evaluator = new RankingEvaluator[Int]()
+    assert(evaluator.evaluate(predictionAndLabels) ~== 0.355026 absTol 0.01)
+
+    // mapk (precision of the top K predictions), k = 5
+    evaluator.setMetricName("mapk").setK(5)
+    assert(evaluator.evaluate(predictionAndLabels) ~== 0.8 / 3 absTol 0.01)
+
+    // ndcg, k = 5
+    evaluator.setMetricName("ndcg")
+    assert(evaluator.evaluate(predictionAndLabels) ~== 0.328788 absTol 0.01)
+
+    // mrr
+    evaluator.setMetricName("mrr")
+    assert(evaluator.evaluate(predictionAndLabels) ~== 0.5 absTol 0.01)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index f334be2c2ba83..29de9111660b0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.mllib.util.TestingUtils._
 
 class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
 
-  test("Ranking metrics: MAP, NDCG") {
+  test("Ranking metrics: MAP, NDCG, MRR, Precision@K") {
     val predictionAndLabels = sc.parallelize(
       Seq(
         (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
@@ -34,6 +34,7 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val metrics = new RankingMetrics(predictionAndLabels)
     val map = metrics.meanAveragePrecision
+    val mrr = metrics.meanReciprocalRank
 
     assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
     assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
@@ -49,6 +50,8 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
     assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
     assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+
+    assert(mrr ~== 0.5 absTol eps)
   }
 
   test("MAP, NDCG with few predictions (SPARK-14886)") {
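
For reviewers who want to exercise the new API end to end, here is a minimal usage sketch. It is not part of the patch: it assumes the diff above has been applied to a Spark 2.0 build, and the object name, local master setting, and toy data are illustrative only.

```scala
import org.apache.spark.ml.evaluation.RankingEvaluator
import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.sql.{Row, SparkSession}

// Hypothetical driver, not part of this patch; names and data are illustrative.
object RankingEvaluatorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("RankingEvaluatorSketch")
      .master("local[2]")
      .getOrCreate()

    // Each row pairs a ranked list of predicted ids with the ground-truth ids.
    val predictionAndLabels = spark.createDataFrame(Seq(
      (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
      (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3))
    )).toDF("prediction", "label")

    // DataFrame API: compute MRR through the new evaluator.
    val evaluator = new RankingEvaluator[Int]()
      .setPredictionCol("prediction")
      .setLabelCol("label")
      .setMetricName("mrr")
    println(s"MRR via RankingEvaluator = ${evaluator.evaluate(predictionAndLabels)}")

    // RDD API: the same metric straight from RankingMetrics.
    // Query 1: first relevant id is at rank 1 -> RR = 1.0
    // Query 2: first relevant id (1) is at rank 2 -> RR = 0.5
    // MRR = (1.0 + 0.5) / 2 = 0.75
    val metrics = new RankingMetrics(predictionAndLabels.rdd.map {
      case Row(pred: Seq[Int] @unchecked, lab: Seq[Int] @unchecked) =>
        (pred.toArray, lab.toArray)
    })
    println(s"MRR via RankingMetrics = ${metrics.meanReciprocalRank}")

    spark.stop()
  }
}
```

The two results agree by construction: `RankingEvaluator` simply extracts the two array columns and delegates to `RankingMetrics`, which is why the evaluator suite reuses the expected values from `RankingMetricsSuite`.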