From 15ba3c26f388c2c9545057f86786e241e6d9bfde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E6=98=9F=E5=8D=9A?= Date: Mon, 11 Jul 2016 19:11:26 +0800 Subject: [PATCH 01/22] [SPARK-16282][SQL] Implement percentile SQL function. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/aggregate/Percentile.scala | 151 ++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 19 +++ .../spark/sql/DataFrameAggregateSuite.scala | 16 ++ 4 files changed, 187 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 007cdc1ccbe4e..2636afe6209ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -249,6 +249,7 @@ object FunctionRegistry { expression[Max]("max"), expression[Average]("mean"), expression[Min]("min"), + expression[Percentile]("percentile"), expression[Skewness]("skewness"), expression[ApproximatePercentile]("percentile_approx"), expression[StddevSamp]("std"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala new file mode 100644 index 0000000000000..13809dea39b54 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashMap + +import scala.collection.mutable + + +/** + * The Percentile aggregate function computes the exact percentile(s) of expr at pc with range in + * [0, 1]. + * The parameter pc can be a DoubleType or DoubleType array. + */ +@ExpressionDescription( + usage = """_FUNC_(epxr, pc) - Returns the percentile(s) of expr at pc (range: [0,1]). 
pc can be + a double or double array.""") +case class Percentile( + child: Expression, + pc: Seq[Double], + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate { + + def this(child: Expression, pc: Double) = { + this(child = child, pc = Seq(pc), mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + } + + override def prettyName: String = "percentile" + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + var counts = new OpenHashMap[Long, Long]() + + override def children: Seq[Expression] = Seq(child) + + override def nullable: Boolean = false + + override def dataType: DataType = ArrayType(DoubleType) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + override def supportsPartial: Boolean = false + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + override val aggBufferAttributes: Seq[AttributeReference] = pc.map(percentile => + AttributeReference(percentile.toString, DoubleType)()) + + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + override def initialize(buffer: MutableRow): Unit = { + for (i <- 0 until pc.size) { + buffer.setNullAt(mutableAggBufferOffset + i) + } + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val v = child.eval(input) + + v match { + case o: Int => counts.changeValue(o.toLong, 1L, _ + 1L) + case o: Long => counts.changeValue(o, 1L, _ + 1L) + case _ => return false + } + } + + override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { + sys.error("Percentile cannot be used in partial aggregations.") + } + + override def eval(buffer: InternalRow): Any = { + if (counts.size == 0) { + return new GenericArrayData(Seq.empty) + } + + // Sort all items and generate a sequence, then accumulate the counts + val sortedCounts = counts.toSeq.sortBy(_._1) + val aggreCounts = sortedCounts.scanLeft(0L, 0L) { (k1: (Long, Long), k2: (Long, Long)) => + (k2._1, k1._2 + k2._2) + }.drop(1) + val maxPosition = aggreCounts.last._2 - 1 + + new GenericArrayData(pc.map { percentile => + if (percentile < 0.0 || percentile > 1.0) { + sys.error("Percentile value must be within the range of 0 to 1.") + } + getPercentile(aggreCounts, maxPosition * percentile) + }) + } + + /** + * Get the percentile value. + */ + private def getPercentile(aggreCounts: Seq[(Long, Long)], position: Double): Double = { + // We may need to do linear interpolation to get the exact percentile + val lower = position.floor + val higher = position.ceil + + // Linear search since this won't take much time from the total execution anyway + // lower has the range of [0 .. total-1] + // The first entry with accumulated count (lower+1) corresponds to the lower position. 
+ var i = 0 + while (aggreCounts(i)._2 < lower + 1) { + i += 1 + } + + val lowerKey = aggreCounts(i)._1 + if (higher == lower) { + // no interpolation needed because position does not have a fraction + return lowerKey + } + + if (aggreCounts(i)._2 < higher + 1) { + i += 1 + } + val higherKey = aggreCounts(i)._1 + + if (higherKey == lowerKey) { + // no interpolation needed because lower position and higher position has the same key + return lowerKey + } + + // Linear interpolation to get the exact percentile + return (higher - position) * lowerKey + (position - lower) * higherKey + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d5940c638acdb..7f5348f6373dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -612,6 +612,25 @@ object functions { */ def min(columnName: String): Column = min(Column(columnName)) + /** + * Aggregate function: returns the exact percentile(s) of the expression in a group at pc with + * range in [0, 1]. + * + * @group agg_funcs + * @since 2.1.0 + */ + def percentile(e: Column, pc: Seq[Double]): Column = + withAggregateFunction {Percentile(e.expr, pc)} + + /** + * Aggregate function: returns the exact percentile(s) of the column in a group at pc with range + * in [0, 1]. + * + * @group agg_funcs + * @since 2.1.0 + */ + def percentile(columnName: String, pc: Seq[Double]): Column = percentile(Column(columnName), pc) + /** * Aggregate function: returns the skewness of the values in a group. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7aa4f0026f275..23553649e0574 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -513,4 +513,20 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))), Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d")))) } + + test("percentile functions") { + val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a") + checkAnswer( + df.select(percentile($"a", Seq(0.5)), percentile($"a", Seq(0, 0.75, 1))), + Seq(Row(Seq(5.5), Seq(1.0, 26.0, 400.0))) + ) + } + + test("percentile functions with zero input rows.") { + val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a").where($"a" < 0) + checkAnswer( + df.select(percentile($"a", Seq(0.5))), + Seq(Row(Seq.empty)) + ) + } } From ef14aab54a6f0c45c8bd4726492e4ff0825d473d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E6=98=9F=E5=8D=9A?= Date: Mon, 11 Jul 2016 19:26:29 +0800 Subject: [PATCH 02/22] remove unused import --- .../spark/sql/catalyst/expressions/aggregate/Percentile.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 13809dea39b54..0df0d47c4e3c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -23,9 +23,6 @@ import org.apache.spark.sql.catalyst.util.GenericArrayData import 
org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashMap
 
-import scala.collection.mutable
-
-
 /**
  * The Percentile aggregate function computes the exact percentile(s) of expr at pc with range in
  * [0, 1].
  * The parameter pc can be a DoubleType or DoubleType array.

From 91ddabd808a16f9dfbffd647f7805093d0e5ff44 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=8B=E6=98=9F=E5=8D=9A?=
Date: Thu, 14 Jul 2016 19:27:38 +0800
Subject: [PATCH 03/22] 1. Add support for all numeric types; 2. Add comments
 to document the memory characteristics; 3. Add basic SQL tests; 4. Some other
 refactoring.

---
 .../expressions/aggregate/Percentile.scala    | 95 ++++++++++++-------
 .../org/apache/spark/sql/functions.scala      | 23 ++++-
 .../spark/sql/DataFrameAggregateSuite.scala   |  4 +-
 .../sql/catalyst/ExpressionToSQLSuite.scala   |  4 +
 4 files changed, 87 insertions(+), 39 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 0df0d47c4e3c4..2879525f74e8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{TypeUtils, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashMap
 
@@ -27,19 +29,29 @@ import org.apache.spark.util.collection.OpenHashMap
  * The Percentile aggregate function computes the exact percentile(s) of expr at pc with range in
  * [0, 1].
  * The parameter pc can be a DoubleType or DoubleType array.
+ *
+ * The operator is bound to the slower sort-based aggregation path because the number of elements
+ * and their partial order cannot be determined in advance. Therefore we have to store all the
+ * elements in memory, and too many elements can cause GC pauses and eventually OutOfMemory
+ * Errors.
  */
 @ExpressionDescription(
  usage = """_FUNC_(epxr, pc) - Returns the percentile(s) of expr at pc (range: [0,1]). 
pc can be a double or double array.""") case class Percentile( - child: Expression, - pc: Seq[Double], - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends ImperativeAggregate { - - def this(child: Expression, pc: Double) = { - this(child = child, pc = Seq(pc), mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + child: Expression, + pc: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends ImperativeAggregate { + + def this(child: Expression, pc: Expression) = { + this(child = child, pc = pc, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + } + + private val percentiles: Seq[Double] = pc match { + case Literal(ar: GenericArrayData, _: ArrayType) => + ar.asInstanceOf[GenericArrayData].array.map{ d => d.asInstanceOf[Double]} + case _ => Seq.empty } override def prettyName: String = "percentile" @@ -50,40 +62,52 @@ case class Percentile( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - var counts = new OpenHashMap[Long, Long]() + private var counts = new OpenHashMap[Double, Long]() - override def children: Seq[Expression] = Seq(child) + override def children: Seq[Expression] = child :: pc :: Nil override def nullable: Boolean = false override def dataType: DataType = ArrayType(DoubleType) - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function percentile") override def supportsPartial: Boolean = false override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - override val aggBufferAttributes: Seq[AttributeReference] = pc.map(percentile => + override val aggBufferAttributes: Seq[AttributeReference] = percentiles.map(percentile => AttributeReference(percentile.toString, DoubleType)()) override val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) override def initialize(buffer: MutableRow): Unit = { - for (i <- 0 until pc.size) { + var i = 0 + while (i < percentiles.size) { buffer.setNullAt(mutableAggBufferOffset + i) + i += 1 } } override def update(buffer: MutableRow, input: InternalRow): Unit = { val v = child.eval(input) - v match { - case o: Int => counts.changeValue(o.toLong, 1L, _ + 1L) - case o: Long => counts.changeValue(o, 1L, _ + 1L) - case _ => return false + val key = v match { + case o: Byte => o.toDouble + case o: Short => o.toDouble + case o: Int => o.toDouble + case o: Long => o.toDouble + case o: Float => o.toDouble + case o: Decimal => o.toDouble + case o: Double => o + case _ => sys.error("Percentile is restricted to Numeric types only.") } + + counts.changeValue(key, 1L, _ + 1L) } override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { @@ -91,29 +115,30 @@ case class Percentile( } override def eval(buffer: InternalRow): Any = { - if (counts.size == 0) { - return new GenericArrayData(Seq.empty) - } - - // Sort all items and generate a sequence, then accumulate the counts - val sortedCounts = counts.toSeq.sortBy(_._1) - val aggreCounts = sortedCounts.scanLeft(0L, 0L) { (k1: (Long, Long), k2: (Long, Long)) => - (k2._1, k1._2 + k2._2) - }.drop(1) - val maxPosition = aggreCounts.last._2 - 1 - - new GenericArrayData(pc.map { percentile => - if (percentile < 0.0 || percentile > 1.0) { - sys.error("Percentile value 
must be within the range of 0 to 1.") + if (percentiles.forall(percentile => percentile >= 0.0 && percentile <= 1.0)) { + if (counts.size == 0) { + return new GenericArrayData(Seq.empty) } - getPercentile(aggreCounts, maxPosition * percentile) - }) + + // Sort all items and generate a sequence, then accumulate the counts + val sortedCounts = counts.toSeq.sortBy(_._1) + val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { + (k1: (Double, Long), k2: (Double, Long)) => (k2._1, k1._2 + k2._2) + }.drop(1) + val maxPosition = aggreCounts.last._2 - 1 + + new GenericArrayData(percentiles.map { percentile => + getPercentile(aggreCounts, maxPosition * percentile) + }) + } else { + sys.error("Percentile value must be within the range of 0 to 1.") + } } /** * Get the percentile value. */ - private def getPercentile(aggreCounts: Seq[(Long, Long)], position: Double): Double = { + private def getPercentile(aggreCounts: Seq[(Double, Long)], position: Double): Double = { // We may need to do linear interpolation to get the exact percentile val lower = position.floor val higher = position.ceil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7f5348f6373dc..ae31bf99e153c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -619,8 +619,9 @@ object functions { * @group agg_funcs * @since 2.1.0 */ - def percentile(e: Column, pc: Seq[Double]): Column = - withAggregateFunction {Percentile(e.expr, pc)} + def percentile(e: Column, pc: Seq[Double]): Column = withAggregateFunction { + Percentile(e.expr, CreateArray(pc.map(v => Literal(v)))) + } /** * Aggregate function: returns the exact percentile(s) of the column in a group at pc with range @@ -631,6 +632,24 @@ object functions { */ def percentile(columnName: String, pc: Seq[Double]): Column = percentile(Column(columnName), pc) + /** + * Aggregate function: returns the exact percentile of the expression in a group at pc with + * range in [0, 1]. + * + * @group agg_funcs + * @since 2.1.0 + */ + def percentile(e: Column, pc: Double): Column = percentile(e, Seq(pc)) + + /** + * Aggregate function: returns the exact percentile of the column in a group at pc with range + * in [0, 1]. + * + * @group agg_funcs + * @since 2.1.0 + */ + def percentile(columnName: String, pc: Double): Column = percentile(Column(columnName), pc) + /** * Aggregate function: returns the skewness of the values in a group. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 23553649e0574..12647976b7a4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -517,7 +517,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("percentile functions") { val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a") checkAnswer( - df.select(percentile($"a", Seq(0.5)), percentile($"a", Seq(0, 0.75, 1))), + df.select(percentile($"a", 0.5d), percentile($"a", Seq(0d, 0.75d, 1d))), Seq(Row(Seq(5.5), Seq(1.0, 26.0, 400.0))) ) } @@ -525,7 +525,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("percentile functions with zero input rows.") { val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a").where($"a" < 0) checkAnswer( - df.select(percentile($"a", Seq(0.5))), + df.select(percentile($"a", 0.5d)), Seq(Row(Seq.empty)) ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index fdd02821dfa29..900da7a28239b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -41,6 +41,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { .saveAsTable("t1") spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + + spark.range(10).select('id as 'key, 'id as 'value).write.saveAsTable("t3") } override protected def afterAll(): Unit = { @@ -154,6 +156,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("aggregate functions") { + checkSqlGeneration("SELECT percentile(value, array(0.5d, 0.9d)) FROM t3 GROUP BY key") checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT percentile_approx(value, 0.25) FROM t1 GROUP BY key") checkSqlGeneration("SELECT percentile_approx(value, array(0.25, 0.75)) FROM t1 GROUP BY key") @@ -173,6 +176,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key") + //checkSqlGeneration("SELECT percentile(value, array(0.5d, 0.9d)) FROM t3 GROUP BY key") checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key") From 324483f1dcd14d075003acef4fab62809e1f6648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E6=98=9F=E5=8D=9A?= Date: Thu, 14 Jul 2016 19:28:30 +0800 Subject: [PATCH 04/22] remove unused import. 
--- .../spark/sql/catalyst/expressions/aggregate/Percentile.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 2879525f74e8f..ff59afb1b0459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ From 6bda505898996ce3d36af010b1a31043ea62c37b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E6=98=9F=E5=8D=9A?= Date: Thu, 14 Jul 2016 19:36:45 +0800 Subject: [PATCH 05/22] fix scala style fail. --- .../spark/sql/catalyst/expressions/aggregate/Percentile.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index ff59afb1b0459..9c736b6332b11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{TypeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap From a29d8b36017b35ee6f55a7ab4a3355f942539870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E6=98=9F=E5=8D=9A?= Date: Thu, 14 Jul 2016 20:09:33 +0800 Subject: [PATCH 06/22] fix scala style fail. 
--- .../org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index 900da7a28239b..3a2b4efc887e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -156,7 +156,6 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("aggregate functions") { - checkSqlGeneration("SELECT percentile(value, array(0.5d, 0.9d)) FROM t3 GROUP BY key") checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT percentile_approx(value, 0.25) FROM t1 GROUP BY key") checkSqlGeneration("SELECT percentile_approx(value, array(0.25, 0.75)) FROM t1 GROUP BY key") @@ -176,7 +175,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key") - //checkSqlGeneration("SELECT percentile(value, array(0.5d, 0.9d)) FROM t3 GROUP BY key") + checkSqlGeneration("SELECT percentile(value, array(0.5d, 0.9d)) FROM t3 GROUP BY key") checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key") From c7193d412badac33437cf15457db9fa6e9e68d90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E6=98=9F=E5=8D=9A?= Date: Sat, 16 Jul 2016 14:33:18 +0800 Subject: [PATCH 07/22] bugfix --- .../expressions/aggregate/Percentile.scala | 102 +++++++++++------- .../org/apache/spark/sql/functions.scala | 4 +- .../spark/sql/DataFrameAggregateSuite.scala | 11 +- .../spark/sql/hive/HiveSessionCatalog.scala | 3 +- .../sql/hive/execution/HiveUDFSuite.scala | 2 +- 5 files changed, 77 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 9c736b6332b11..fb93e15c662ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -47,12 +47,44 @@ case class Percentile( this(child = child, pc = pc, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) } - private val percentiles: Seq[Double] = pc match { + private val percentiles: Seq[Number] = pc match { + case e: Literal => + analyzePercentile(e) + case CreateArray(e: Seq[Expression]) => + analyzePercentile(e) + case e: Expression => + analyzePercentile(e.children) + case _ => sys.error("Percentiles expression cannot be analyzed.") + } + + private def analyzePercentile(e: Expression): Seq[Number] = e match { case Literal(ar: GenericArrayData, _: ArrayType) => - ar.asInstanceOf[GenericArrayData].array.map{ d => d.asInstanceOf[Double]} - case _ => Seq.empty + ar.asInstanceOf[GenericArrayData].array.map{ d => d.asInstanceOf[Number]} + case Literal(d: Any, _: NumericType) => + Seq(d.asInstanceOf[Number]) + case PrettyAttribute(n: String, _: NumericType) => + Seq(n.toDouble.asInstanceOf[Number]) + case _ => 
sys.error("Percentiles expression cannot be analyzed.") } + private def analyzePercentile(e: Seq[Expression]): Seq[Number] = { + e.map { expr => + expr match { + case Literal(d: Any, _: NumericType) => + d.asInstanceOf[Number] + case PrettyAttribute(n: String, _: NumericType) => + n.toDouble.asInstanceOf[Number] + case _ => sys.error("Percentiles expression cannot be analyzed.") + } + } + } + + require(percentiles.size > 0, "Percentiles should not be empty.") + + require(percentiles.forall(percentile => + percentile.doubleValue() >= 0.0 && percentile.doubleValue() <= 1.0), + "Percentile value must be within the range of 0 to 1.") + override def prettyName: String = "percentile" override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = @@ -61,7 +93,7 @@ case class Percentile( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - private var counts = new OpenHashMap[Double, Long]() + private var counts = new OpenHashMap[Number, Long] override def children: Seq[Expression] = child :: pc :: Nil @@ -78,11 +110,9 @@ case class Percentile( override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - override val aggBufferAttributes: Seq[AttributeReference] = percentiles.map(percentile => - AttributeReference(percentile.toString, DoubleType)()) + override val aggBufferAttributes: Seq[AttributeReference] = Nil - override val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) + override val inputAggBufferAttributes: Seq[AttributeReference] = Nil override def initialize(buffer: MutableRow): Unit = { var i = 0 @@ -90,21 +120,13 @@ case class Percentile( buffer.setNullAt(mutableAggBufferOffset + i) i += 1 } + + //The counts openhashmap will contain values of other groups if we don't initialize it here. 
+ counts = new OpenHashMap[Number, Long] } override def update(buffer: MutableRow, input: InternalRow): Unit = { - val v = child.eval(input) - - val key = v match { - case o: Byte => o.toDouble - case o: Short => o.toDouble - case o: Int => o.toDouble - case o: Long => o.toDouble - case o: Float => o.toDouble - case o: Decimal => o.toDouble - case o: Double => o - case _ => sys.error("Percentile is restricted to Numeric types only.") - } + val key = child.eval(input).asInstanceOf[Number] counts.changeValue(key, 1L, _ + 1L) } @@ -114,30 +136,31 @@ case class Percentile( } override def eval(buffer: InternalRow): Any = { - if (percentiles.forall(percentile => percentile >= 0.0 && percentile <= 1.0)) { - if (counts.size == 0) { - return new GenericArrayData(Seq.empty) - } - - // Sort all items and generate a sequence, then accumulate the counts - val sortedCounts = counts.toSeq.sortBy(_._1) - val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { - (k1: (Double, Long), k2: (Double, Long)) => (k2._1, k1._2 + k2._2) - }.drop(1) - val maxPosition = aggreCounts.last._2 - 1 - - new GenericArrayData(percentiles.map { percentile => - getPercentile(aggreCounts, maxPosition * percentile) - }) - } else { - sys.error("Percentile value must be within the range of 0 to 1.") + if (counts.isEmpty) { + return new GenericArrayData(Seq.empty) + } + // Sort all items and generate a sequence, then accumulate the counts + var ascOrder = new Ordering[Int](){ + override def compare(a:Int,b:Int):Int = a - b } + val sortedCounts = counts.toSeq.sortBy(_._1)(new Ordering[Number]() { + override def compare(a: Number, b: Number): Int = + scala.math.signum(a.doubleValue() - b.doubleValue()).toInt + }) + val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { + (k1: (Number, Long), k2: (Number, Long)) => (k2._1, k1._2 + k2._2) + }.drop(1) + val maxPosition = aggreCounts.last._2 - 1 + + new GenericArrayData(percentiles.map { percentile => + getPercentile(aggreCounts, maxPosition * percentile.doubleValue()).doubleValue() + }) } /** * Get the percentile value. 
*/ - private def getPercentile(aggreCounts: Seq[(Double, Long)], position: Double): Double = { + private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { // We may need to do linear interpolation to get the exact percentile val lower = position.floor val higher = position.ceil @@ -167,6 +190,7 @@ case class Percentile( } // Linear interpolation to get the exact percentile - return (higher - position) * lowerKey + (position - lower) * higherKey + return (higher - position) * lowerKey.doubleValue() + + (position - lower) * higherKey.doubleValue() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ae31bf99e153c..77f1c391773a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -639,7 +639,9 @@ object functions { * @group agg_funcs * @since 2.1.0 */ - def percentile(e: Column, pc: Double): Column = percentile(e, Seq(pc)) + def percentile(e: Column, pc: Double): Column = withAggregateFunction { + Percentile(e.expr, Literal(pc)) + } /** * Aggregate function: returns the exact percentile of the column in a group at pc with range diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 12647976b7a4c..eceb1dff3766f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -517,7 +517,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("percentile functions") { val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a") checkAnswer( - df.select(percentile($"a", 0.5d), percentile($"a", Seq(0d, 0.75d, 1d))), + df.select(percentile($"a", 0.5), percentile($"a", Seq(0, 0.75, 1))), Seq(Row(Seq(5.5), Seq(1.0, 26.0, 400.0))) ) } @@ -525,8 +525,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("percentile functions with zero input rows.") { val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a").where($"a" < 0) checkAnswer( - df.select(percentile($"a", 0.5d)), + df.select(percentile($"a", 0.5)), Seq(Row(Seq.empty)) ) } + + test("percentile functions with empty percentile param") { + val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a") + val error = intercept[IllegalArgumentException] { + df.select(percentile($"a", Seq())) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 4a9b28a455a44..08bf1cd0efbb9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -234,7 +234,6 @@ private[sql] class HiveSessionCatalog( // noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction. 
// Note: don't forget to update SessionCatalog.isTemporaryFunction private val hiveFunctions = Seq( - "histogram_numeric", - "percentile" + "histogram_numeric" ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 48adc833f4b22..2235ace6d98ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -136,7 +136,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), - sql("SELECT max(key) FROM src").collect().toSeq) + sql("SELECT array(max(key)) FROM src").collect().toSeq) checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) From d21a1042f5d1cc035daad92b1693333e3e7c8840 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E6=98=9F=E5=8D=9A?= Date: Sat, 16 Jul 2016 19:17:15 +0800 Subject: [PATCH 08/22] add testcases. --- .../expressions/aggregate/Percentile.scala | 80 ++++++++----------- .../spark/sql/DataFrameAggregateSuite.scala | 6 +- .../execution/AggregationQuerySuite.scala | 36 +++++++++ 3 files changed, 74 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index fb93e15c662ff..7eb4018c53e9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.collection.OpenHashMap * Errors. */ @ExpressionDescription( - usage = """_FUNC_(epxr, pc) - Returns the percentile(s) of expr at pc (range: [0,1]). pc can be + usage = """_FUNC_(expr, pc) - Returns the percentile(s) of expr at pc (range: [0,1]). 
pc can be a double or double array.""") case class Percentile( child: Expression, @@ -47,44 +47,6 @@ case class Percentile( this(child = child, pc = pc, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) } - private val percentiles: Seq[Number] = pc match { - case e: Literal => - analyzePercentile(e) - case CreateArray(e: Seq[Expression]) => - analyzePercentile(e) - case e: Expression => - analyzePercentile(e.children) - case _ => sys.error("Percentiles expression cannot be analyzed.") - } - - private def analyzePercentile(e: Expression): Seq[Number] = e match { - case Literal(ar: GenericArrayData, _: ArrayType) => - ar.asInstanceOf[GenericArrayData].array.map{ d => d.asInstanceOf[Number]} - case Literal(d: Any, _: NumericType) => - Seq(d.asInstanceOf[Number]) - case PrettyAttribute(n: String, _: NumericType) => - Seq(n.toDouble.asInstanceOf[Number]) - case _ => sys.error("Percentiles expression cannot be analyzed.") - } - - private def analyzePercentile(e: Seq[Expression]): Seq[Number] = { - e.map { expr => - expr match { - case Literal(d: Any, _: NumericType) => - d.asInstanceOf[Number] - case PrettyAttribute(n: String, _: NumericType) => - n.toDouble.asInstanceOf[Number] - case _ => sys.error("Percentiles expression cannot be analyzed.") - } - } - } - - require(percentiles.size > 0, "Percentiles should not be empty.") - - require(percentiles.forall(percentile => - percentile.doubleValue() >= 0.0 && percentile.doubleValue() <= 1.0), - "Percentile value must be within the range of 0 to 1.") - override def prettyName: String = "percentile" override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = @@ -115,20 +77,43 @@ case class Percentile( override val inputAggBufferAttributes: Seq[AttributeReference] = Nil override def initialize(buffer: MutableRow): Unit = { - var i = 0 - while (i < percentiles.size) { - buffer.setNullAt(mutableAggBufferOffset + i) - i += 1 + //The counts OpenHashMap will contain values of other groups if we don't initialize it here. + //Since OpenHashMap doesn't support deletions, we have to create a new instance. + counts = new OpenHashMap[Number, Long] + } + + private def evalPercentiles(input: InternalRow): Seq[Number] = { + val exprs = children + val percentiles: Seq[Number] = children(1).eval(input) match { + case ar: GenericArrayData => + ar.asInstanceOf[GenericArrayData].array.map{ d => d.asInstanceOf[Number]} + case d: Number => + Seq(d.asInstanceOf[Number]) + case d: Decimal => + Seq(d.toDouble.asInstanceOf[Number]) + case _ => + sys.error("Percentiles expression cannot be analyzed.") } - //The counts openhashmap will contain values of other groups if we don't initialize it here. - counts = new OpenHashMap[Number, Long] + require(percentiles.size > 0, "Percentiles should not be empty.") + + require(percentiles.forall(percentile => + percentile.doubleValue() >= 0.0 && percentile.doubleValue() <= 1.0), + "Percentile value must be within the range of 0 to 1.") + + percentiles } override def update(buffer: MutableRow, input: InternalRow): Unit = { + //Eval percentiles and check whether its value is valid. + val percentiles = evalPercentiles(input) + val key = child.eval(input).asInstanceOf[Number] - counts.changeValue(key, 1L, _ + 1L) + //Null values are ignored when computing percentiles. 
+ if (key != null) { + counts.changeValue(key, 1L, _ + 1L) + } } override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { @@ -139,6 +124,9 @@ case class Percentile( if (counts.isEmpty) { return new GenericArrayData(Seq.empty) } + + val percentiles = evalPercentiles(buffer) + // Sort all items and generate a sequence, then accumulate the counts var ascOrder = new Ordering[Int](){ override def compare(a:Int,b:Int):Int = a - b diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index eceb1dff3766f..ee52b9be2c81a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.SparkException import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -532,8 +533,9 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("percentile functions with empty percentile param") { val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a") - val error = intercept[IllegalArgumentException] { - df.select(percentile($"a", Seq())) + val error = intercept[SparkException] { + df.select(percentile($"a", Seq())).collect() } + assert(error.getMessage.contains("Percentiles should not be empty.")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4a8086d7e5400..e776b4cba8bc6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -851,6 +851,42 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0)) } + test("percentile") { + checkAnswer( + spark.sql( + """ + |SELECT + | percentile(value1, 0.5) + |FROM agg2 + |LIMIT 1 + """.stripMargin), + Row(Seq(1)) :: Nil) + + checkAnswer( + spark.sql( + """ + |SELECT + | percentile(value1, array(0d, 0.5d, 1d)) + |FROM agg2 + |LIMIT 1 + """.stripMargin), + Row(Seq(-60, 1, 100)) :: Nil) + + checkAnswer( + spark.sql( + """ + |SELECT + | key, + | percentile(value1, array(0d, 0.5d, 1d)) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(null, Seq(-60, -10, 100)) :: + Row(1, Seq(10, 30, 30)) :: + Row(2, Seq(-1, 1, 1)) :: + Row(3, Seq.empty) :: Nil) + } + test("no aggregation function (SPARK-11486)") { val df = spark.range(20).selectExpr("id", "repeat(id, 1) as s") .groupBy("s").count() From 8eebb6af9e19bf1af1ecb3ad0206fcfa4bbc8097 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E6=98=9F=E5=8D=9A?= Date: Sat, 16 Jul 2016 19:37:03 +0800 Subject: [PATCH 09/22] fix scala style fail. 
--- .../catalyst/expressions/aggregate/Percentile.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 7eb4018c53e9b..cc4b3c49e82a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -77,8 +77,8 @@ case class Percentile( override val inputAggBufferAttributes: Seq[AttributeReference] = Nil override def initialize(buffer: MutableRow): Unit = { - //The counts OpenHashMap will contain values of other groups if we don't initialize it here. - //Since OpenHashMap doesn't support deletions, we have to create a new instance. + // The counts OpenHashMap will contain values of other groups if we don't initialize it here. + // Since OpenHashMap doesn't support deletions, we have to create a new instance. counts = new OpenHashMap[Number, Long] } @@ -105,12 +105,12 @@ case class Percentile( } override def update(buffer: MutableRow, input: InternalRow): Unit = { - //Eval percentiles and check whether its value is valid. + // Eval percentiles and check whether its value is valid. val percentiles = evalPercentiles(input) val key = child.eval(input).asInstanceOf[Number] - //Null values are ignored when computing percentiles. + // Null values are ignored when computing percentiles. if (key != null) { counts.changeValue(key, 1L, _ + 1L) } @@ -128,8 +128,8 @@ case class Percentile( val percentiles = evalPercentiles(buffer) // Sort all items and generate a sequence, then accumulate the counts - var ascOrder = new Ordering[Int](){ - override def compare(a:Int,b:Int):Int = a - b + var ascOrder = new Ordering[Int]() { + override def compare(a: Int, b: Int): Int = a - b } val sortedCounts = counts.toSeq.sortBy(_._1)(new Ordering[Number]() { override def compare(a: Number, b: Number): Int = From 79a2b97c6b58845142753ed9c5ef2841b87ed4c7 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Mon, 18 Jul 2016 20:39:05 +0800 Subject: [PATCH 10/22] add testcases for windowing. 
--- .../sql/DataFrameWindowFunctionsSuite.scala | 11 ++-- .../HiveWindowFunctionQuerySuite.scala | 2 +- .../sql/hive/execution/WindowQuerySuite.scala | 55 ++++++++++--------- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1255c49104718..a730500cfbb2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -116,11 +116,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { dense_rank().over(Window.partitionBy("value").orderBy("key")), rank().over(Window.partitionBy("value").orderBy("key")), cume_dist().over(Window.partitionBy("value").orderBy("key")), - percent_rank().over(Window.partitionBy("value").orderBy("key"))), - Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: - Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) + percent_rank().over(Window.partitionBy("value").orderBy("key")), + percentile("key", 0.5).over(Window.partitionBy("value").orderBy("key"))), + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d, Seq(1.0)) :: + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d, Seq(1.0)) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d, Seq(2.0)) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d, Seq(2.0)) :: Nil) } test("window function should fail if order by clause is not specified") { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 7ba5790c2979d..f12ca0dc55759 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -538,7 +538,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte """ |select p_mfgr,p_name, p_size, |histogram_numeric(p_retailprice, 5) over w1 as hist, - |percentile(p_partkey, cast(0.5 as double)) over w1 as per, + |percentile(p_partkey, array(0.5, 0.9)) over w1 as per, |row_number() over(distribute by p_mfgr sort by p_name) as rn |from part |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index 0ff3511c87a4f..2b65f51f888a2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -75,39 +75,40 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto |stddev(p_size) over(distribute by p_mfgr sort by p_name) as st, |first_value(p_size % 5) over(distribute by p_mfgr sort by p_name) as fv, |last_value(p_size) over(distribute by p_mfgr sort by p_name) as lv, - |first_value(p_size) over w1 as fvW1 + |first_value(p_size) over w1 as fvW1, + |percentile(p_size, 0.5) over w1 as per |from part |window w1 as (distribute by p_mfgr sort by p_mfgr, 
p_name | rows between 2 preceding and 2 following) """.stripMargin), // scalastyle:off Seq( - Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), - Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), - Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2), - Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2), - Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34), - Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6), - Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14), - Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14), - Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14), - Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40), - Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2), - Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17), - Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17), - Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17), - Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14), - Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19), - Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10), - Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10), - Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10), - Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39), - Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27), - Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31), - Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31), - Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31), - Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6), - Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2))) + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 
0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2, Seq(2.0)), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2, Seq(4.0)), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2, Seq(6.0)), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2, Seq(28.0)), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34, Seq(31.0)), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6, Seq(28.0)), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14, Seq(14.0)), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14, Seq(19.5)), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14, Seq(18.0)), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40, Seq(21.5)), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2, Seq(18.0)), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17, Seq(17.0)), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17, Seq(15.5)), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17, Seq(17.0)), + Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14, Seq(16.5)), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19, Seq(19.0)), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10, Seq(27.0)), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10, Seq(18.5)), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10, Seq(12.0)), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39, Seq(19.5)), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27, Seq(12.0)), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31, Seq(6.0)), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31, Seq(18.5)), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31, Seq(23.0)), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6, Seq(14.5)), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 
18.1190507477627, 1, 23, 2, Seq(23.0)))) // scalastyle:on } From 2ae7b48117500ae2221aeda52642f8528117d76d Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Tue, 19 Jul 2016 15:34:56 +0800 Subject: [PATCH 11/22] fix testcase fail. --- .../HiveWindowFunctionQuerySuite.scala | 2 +- ...stDISTs-0-d9065e533430691d70b3370174fbbd50 | 52 +++++++++---------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index f12ca0dc55759..7ba5790c2979d 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -538,7 +538,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte """ |select p_mfgr,p_name, p_size, |histogram_numeric(p_retailprice, 5) over w1 as hist, - |percentile(p_partkey, array(0.5, 0.9)) over w1 as per, + |percentile(p_partkey, cast(0.5 as double)) over w1 as per, |row_number() over(distribute by p_mfgr sort by p_name) as rn |from part |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 b/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 index e7c39f454fb37..6dbaaad364ea8 100644 --- a/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 +++ b/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 @@ -1,26 +1,26 @@ -Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1753.76,"y":1.0}] 121152.0 1 -Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] 115872.0 2 -Manufacturer#1 almond antique chartreuse lavender yellow 34 [{"x":1173.15,"y":2.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] 110592.0 3 -Manufacturer#1 almond antique salmon chartreuse burlywood 6 [{"x":1173.15,"y":1.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] 86428.0 4 -Manufacturer#1 almond aquamarine burnished black steel 28 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] 86098.0 5 -Manufacturer#1 almond aquamarine pink moccasin thistle 42 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0}] 86428.0 6 -Manufacturer#2 almond antique violet chocolate turquoise 14 [{"x":1690.68,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 146985.0 1 -Manufacturer#2 almond antique violet turquoise frosted 40 [{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 139825.5 2 -Manufacturer#2 almond aquamarine midnight light salmon 2 [{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 146985.0 3 -Manufacturer#2 almond aquamarine rose maroon antique 25 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 169347.0 4 -Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":2031.98,"y":1.0}] 146985.0 5 -Manufacturer#3 almond antique chartreuse khaki white 17 
[{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0}] 90681.0 1 -Manufacturer#3 almond antique forest lavender goldenrod 14 [{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] 65831.5 2 -Manufacturer#3 almond antique metallic orange dim 19 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] 90681.0 3 -Manufacturer#3 almond antique misty red olive 1 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] 76690.0 4 -Manufacturer#3 almond antique olive coral navajo 45 [{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] 112398.0 5 -Manufacturer#4 almond antique gainsboro frosted violet 10 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0}] 48427.0 1 -Manufacturer#4 almond antique violet mint lemon 39 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] 46844.0 2 -Manufacturer#4 almond aquamarine floral ivory bisque 27 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] 45261.0 3 -Manufacturer#4 almond aquamarine yellow dodger mint 7 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1844.92,"y":1.0}] 39309.0 4 -Manufacturer#4 almond azure aquamarine papaya violet 12 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1844.92,"y":1.0}] 33357.0 5 -Manufacturer#5 almond antique blue firebrick mint 31 [{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] 155733.0 1 -Manufacturer#5 almond antique medium spring khaki 6 [{"x":1018.1,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] 99201.0 2 -Manufacturer#5 almond antique sky peru orange 2 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] 78486.0 3 -Manufacturer#5 almond aquamarine dodger light gainsboro 46 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0}] 60577.5 4 -Manufacturer#5 almond azure blanched chiffon midnight 23 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1788.73,"y":1.0}] 78486.0 5 +Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1753.76,"y":1.0}] [121152.0] 1 +Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] [115872.0] 2 +Manufacturer#1 almond antique chartreuse lavender yellow 34 [{"x":1173.15,"y":2.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] [110592.0] 3 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 [{"x":1173.15,"y":1.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] [86428.0] 4 +Manufacturer#1 almond aquamarine burnished black steel 28 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] [86098.0] 5 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0}] [86428.0] 6 +Manufacturer#2 almond antique violet chocolate turquoise 14 [{"x":1690.68,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] [146985.0] 1 +Manufacturer#2 almond antique violet turquoise frosted 40 [{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] [139825.5] 2 +Manufacturer#2 almond aquamarine midnight light salmon 2 
[{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] [146985.0] 3 +Manufacturer#2 almond aquamarine rose maroon antique 25 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] [169347.0] 4 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":2031.98,"y":1.0}] [146985.0] 5 +Manufacturer#3 almond antique chartreuse khaki white 17 [{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0}] [90681.0] 1 +Manufacturer#3 almond antique forest lavender goldenrod 14 [{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] [65831.5] 2 +Manufacturer#3 almond antique metallic orange dim 19 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] [90681.0] 3 +Manufacturer#3 almond antique misty red olive 1 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] [76690.0] 4 +Manufacturer#3 almond antique olive coral navajo 45 [{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] [112398.0] 5 +Manufacturer#4 almond antique gainsboro frosted violet 10 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0}] [48427.0] 1 +Manufacturer#4 almond antique violet mint lemon 39 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] [46844.0] 2 +Manufacturer#4 almond aquamarine floral ivory bisque 27 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] [45261.0] 3 +Manufacturer#4 almond aquamarine yellow dodger mint 7 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1844.92,"y":1.0}] [39309.0] 4 +Manufacturer#4 almond azure aquamarine papaya violet 12 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1844.92,"y":1.0}] [33357.0] 5 +Manufacturer#5 almond antique blue firebrick mint 31 [{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] [155733.0] 1 +Manufacturer#5 almond antique medium spring khaki 6 [{"x":1018.1,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] [99201.0] 2 +Manufacturer#5 almond antique sky peru orange 2 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] [78486.0] 3 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0}] [60577.5] 4 +Manufacturer#5 almond azure blanched chiffon midnight 23 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1788.73,"y":1.0}] [78486.0] 5 From 8f24a9bfed2d3e7325bf5a4e1779fe887637dfb5 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Mon, 17 Oct 2016 16:17:43 +0800 Subject: [PATCH 12/22] MutableRow is replaced by InternalRow. 
--- .../sql/catalyst/expressions/aggregate/Percentile.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index cc4b3c49e82a7..7304cdb65582c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -76,7 +76,7 @@ case class Percentile( override val inputAggBufferAttributes: Seq[AttributeReference] = Nil - override def initialize(buffer: MutableRow): Unit = { + override def initialize(buffer: InternalRow): Unit = { // The counts OpenHashMap will contain values of other groups if we don't initialize it here. // Since OpenHashMap doesn't support deletions, we have to create a new instance. counts = new OpenHashMap[Number, Long] @@ -104,7 +104,7 @@ case class Percentile( percentiles } - override def update(buffer: MutableRow, input: InternalRow): Unit = { + override def update(buffer: InternalRow, input: InternalRow): Unit = { // Eval percentiles and check whether its value is valid. val percentiles = evalPercentiles(input) @@ -116,7 +116,7 @@ case class Percentile( } } - override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { + override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = { sys.error("Percentile cannot be used in partial aggregations.") } From 59a61cfb4ca9685f0dd0202ae5137b1ffcee667f Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Mon, 24 Oct 2016 18:24:42 +0800 Subject: [PATCH 13/22] refactor some code; remove unnecessary test cases. --- .../expressions/aggregate/Percentile.scala | 93 +++++++++++-------- .../org/apache/spark/sql/functions.scala | 40 -------- .../spark/sql/DataFrameAggregateSuite.scala | 24 ----- .../sql/DataFrameWindowFunctionsSuite.scala | 11 +-- .../execution/AggregationQuerySuite.scala | 36 ------- .../sql/hive/execution/HiveUDFSuite.scala | 2 +- .../sql/hive/execution/WindowQuerySuite.scala | 52 +++++------ 7 files changed, 86 insertions(+), 172 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 7304cdb65582c..94c5f61795234 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ @@ -25,26 +26,37 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap /** - * The Percentile aggregate function computes the exact percentile(s) of expr at pc with range in - * [0, 1]. - * The parameter pc can be a DoubleType or DoubleType array. + * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at + * the given percentage(s) with value range in [0.0, 1.0]. 
* * The operator is bound to the slower sort based aggregation path because the number of elements * and their partial order cannot be determined in advance. Therefore we have to store all the * elements in memory, and too many elements can cause GC pauses and eventually * OutOfMemoryErrors. + * + * @param child child expression that produces the numeric column value via `child.eval(inputRow)` + * @param percentageExpression Expression that represents a single percentage value or an array of + * percentage values. Each percentage value must be in the range + * [0.0, 1.0]. */ @ExpressionDescription( usage = - """_FUNC_(expr, pc) - Returns the percentile(s) of expr at pc (range: [0,1]). pc can be - a double or double array.""") + usage = + """ + _FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the + given percentage. The value of percentage must be between 0.0 and 1.0. + + _FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array + of numeric column `col` at the given percentage(s). Each value of the percentage array must + be between 0.0 and 1.0. + """) case class Percentile( child: Expression, - pc: Expression, + percentageExpression: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends ImperativeAggregate { - def this(child: Expression, pc: Expression) = { - this(child = child, pc = pc, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + def this(child: Expression, percentageExpression: Expression) = { + this(child, percentageExpression, 0, 0) } override def prettyName: String = "percentile" @@ -57,13 +69,19 @@ case class Percentile( private var counts = new OpenHashMap[Number, Long] - override def children: Seq[Expression] = child :: pc :: Nil + // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
+ private lazy val (returnPercentileArray: Boolean, percentages: Seq[Number]) = + evalPercentages(percentageExpression) + + override def children: Seq[Expression] = child :: percentageExpression :: Nil override def nullable: Boolean = false - override def dataType: DataType = ArrayType(DoubleType) + override def dataType: DataType = + if (returnPercentileArray) ArrayType(DoubleType) else DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType) + override def inputTypes: Seq[AbstractDataType] = + Seq(NumericType, TypeCollection(NumericType, ArrayType)) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForOrderingExpr(child.dataType, "function percentile") @@ -82,32 +100,25 @@ case class Percentile( counts = new OpenHashMap[Number, Long] } - private def evalPercentiles(input: InternalRow): Seq[Number] = { - val exprs = children - val percentiles: Seq[Number] = children(1).eval(input) match { - case ar: GenericArrayData => - ar.asInstanceOf[GenericArrayData].array.map{ d => d.asInstanceOf[Number]} - case d: Number => - Seq(d.asInstanceOf[Number]) - case d: Decimal => - Seq(d.toDouble.asInstanceOf[Number]) - case _ => - sys.error("Percentiles expression cannot be analyzed.") + private def evalPercentages(expr: Expression): (Boolean, Seq[Number]) = { + val (isArrayType, values) = (expr.dataType, expr.eval()) match { + case (_, n: Number) => (false, Seq(n)) + case (_, d: Decimal) => (false, Seq(d.toDouble.asInstanceOf[Number])) + case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => + (true, arrayData.toArray[Number](baseType).toSeq) + case other => + throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") } - require(percentiles.size > 0, "Percentiles should not be empty.") + require(values.size > 0, s"Percentage values should not be empty.") - require(percentiles.forall(percentile => - percentile.doubleValue() >= 0.0 && percentile.doubleValue() <= 1.0), - "Percentile value must be within the range of 0 to 1.") + require(values.forall(value => value.doubleValue() >= 0.0 && value.doubleValue() <= 1.0), + s"Percentage values must be between 0.0 and 1.0, current values = ${values.mkString(", ")}") - percentiles + (isArrayType, values) } override def update(buffer: InternalRow, input: InternalRow): Unit = { - // Eval percentiles and check whether its value is valid. - val percentiles = evalPercentiles(input) - val key = child.eval(input).asInstanceOf[Number] // Null values are ignored when computing percentiles. 
@@ -122,29 +133,33 @@ case class Percentile( override def eval(buffer: InternalRow): Any = { if (counts.isEmpty) { - return new GenericArrayData(Seq.empty) + return generateOutput(Seq.empty) } - val percentiles = evalPercentiles(buffer) - - // Sort all items and generate a sequence, then accumulate the counts - var ascOrder = new Ordering[Int]() { - override def compare(a: Int, b: Int): Int = a - b - } val sortedCounts = counts.toSeq.sortBy(_._1)(new Ordering[Number]() { override def compare(a: Number, b: Number): Int = scala.math.signum(a.doubleValue() - b.doubleValue()).toInt }) val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { (k1: (Number, Long), k2: (Number, Long)) => (k2._1, k1._2 + k2._2) - }.drop(1) + }.tail val maxPosition = aggreCounts.last._2 - 1 - new GenericArrayData(percentiles.map { percentile => + generateOutput(percentages.map { percentile => getPercentile(aggreCounts, maxPosition * percentile.doubleValue()).doubleValue() }) } + private def generateOutput(results: Seq[Double]): Any = { + if (results.isEmpty) { + null + } else if (returnPercentileArray) { + new GenericArrayData(results) + } else { + results.head + } + } + /** * Get the percentile value. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 77f1c391773a9..d5940c638acdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -612,46 +612,6 @@ object functions { */ def min(columnName: String): Column = min(Column(columnName)) - /** - * Aggregate function: returns the exact percentile(s) of the expression in a group at pc with - * range in [0, 1]. - * - * @group agg_funcs - * @since 2.1.0 - */ - def percentile(e: Column, pc: Seq[Double]): Column = withAggregateFunction { - Percentile(e.expr, CreateArray(pc.map(v => Literal(v)))) - } - - /** - * Aggregate function: returns the exact percentile(s) of the column in a group at pc with range - * in [0, 1]. - * - * @group agg_funcs - * @since 2.1.0 - */ - def percentile(columnName: String, pc: Seq[Double]): Column = percentile(Column(columnName), pc) - - /** - * Aggregate function: returns the exact percentile of the expression in a group at pc with - * range in [0, 1]. - * - * @group agg_funcs - * @since 2.1.0 - */ - def percentile(e: Column, pc: Double): Column = withAggregateFunction { - Percentile(e.expr, Literal(pc)) - } - - /** - * Aggregate function: returns the exact percentile of the column in a group at pc with range - * in [0, 1]. - * - * @group agg_funcs - * @since 2.1.0 - */ - def percentile(columnName: String, pc: Double): Column = percentile(Column(columnName), pc) - /** * Aggregate function: returns the skewness of the values in a group. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index ee52b9be2c81a..c08d349df1b9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -514,28 +514,4 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))), Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d")))) } - - test("percentile functions") { - val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a") - checkAnswer( - df.select(percentile($"a", 0.5), percentile($"a", Seq(0, 0.75, 1))), - Seq(Row(Seq(5.5), Seq(1.0, 26.0, 400.0))) - ) - } - - test("percentile functions with zero input rows.") { - val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a").where($"a" < 0) - checkAnswer( - df.select(percentile($"a", 0.5)), - Seq(Row(Seq.empty)) - ) - } - - test("percentile functions with empty percentile param") { - val df = Seq(1, 3, 3, 6, 5, 4, 17, 38, 29, 400).toDF("a") - val error = intercept[SparkException] { - df.select(percentile($"a", Seq())).collect() - } - assert(error.getMessage.contains("Percentiles should not be empty.")) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index a730500cfbb2c..1255c49104718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -116,12 +116,11 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { dense_rank().over(Window.partitionBy("value").orderBy("key")), rank().over(Window.partitionBy("value").orderBy("key")), cume_dist().over(Window.partitionBy("value").orderBy("key")), - percent_rank().over(Window.partitionBy("value").orderBy("key")), - percentile("key", 0.5).over(Window.partitionBy("value").orderBy("key"))), - Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d, Seq(1.0)) :: - Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d, Seq(1.0)) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d, Seq(2.0)) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d, Seq(2.0)) :: Nil) + percent_rank().over(Window.partitionBy("value").orderBy("key"))), + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) } test("window function should fail if order by clause is not specified") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index e776b4cba8bc6..4a8086d7e5400 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -851,42 +851,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0)) } - test("percentile") { - checkAnswer( - spark.sql( - """ - |SELECT - | percentile(value1, 0.5) - |FROM agg2 - |LIMIT 1 - 
""".stripMargin), - Row(Seq(1)) :: Nil) - - checkAnswer( - spark.sql( - """ - |SELECT - | percentile(value1, array(0d, 0.5d, 1d)) - |FROM agg2 - |LIMIT 1 - """.stripMargin), - Row(Seq(-60, 1, 100)) :: Nil) - - checkAnswer( - spark.sql( - """ - |SELECT - | key, - | percentile(value1, array(0d, 0.5d, 1d)) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(null, Seq(-60, -10, 100)) :: - Row(1, Seq(10, 30, 30)) :: - Row(2, Seq(-1, 1, 1)) :: - Row(3, Seq.empty) :: Nil) - } - test("no aggregation function (SPARK-11486)") { val df = spark.range(20).selectExpr("id", "repeat(id, 1) as s") .groupBy("s").count() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 2235ace6d98ad..48adc833f4b22 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -136,7 +136,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), - sql("SELECT array(max(key)) FROM src").collect().toSeq) + sql("SELECT max(key) FROM src").collect().toSeq) checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index 2b65f51f888a2..5ea85aabe209c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -83,32 +83,32 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto """.stripMargin), // scalastyle:off Seq( - Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2, Seq(2.0)), - Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2, Seq(4.0)), - Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2, Seq(6.0)), - Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2, Seq(28.0)), - Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34, Seq(31.0)), - Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6, Seq(28.0)), - Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14, Seq(14.0)), - Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14, Seq(19.5)), - Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14, Seq(18.0)), - Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40, Seq(21.5)), - Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 
14.042791745233567, 4, 18, 2, Seq(18.0)), - Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17, Seq(17.0)), - Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17, Seq(15.5)), - Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17, Seq(17.0)), - Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14, Seq(16.5)), - Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19, Seq(19.0)), - Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10, Seq(27.0)), - Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10, Seq(18.5)), - Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10, Seq(12.0)), - Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39, Seq(19.5)), - Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27, Seq(12.0)), - Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31, Seq(6.0)), - Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31, Seq(18.5)), - Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31, Seq(23.0)), - Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6, Seq(14.5)), - Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2, Seq(23.0)))) + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2, 2.0), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2, 4.0), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2, 6.0), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2, 28.0), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34, 31.0), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6, 28.0), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14, 14.0), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14, 19.5), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14, 18.0), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40, 21.5), + Row("Manufacturer#2", "almond 
aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2, 18.0), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17, 17.0), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17, 15.5), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17, 17.0), + Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14, 16.5), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19, 19.0), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10, 27.0), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10, 18.5), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10, 12.0), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39, 19.5), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27, 12.0), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31, 6.0), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31, 18.5), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31, 23.0), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6, 14.5), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2, 23.0))) // scalastyle:on } From 7ad1a3518b34aedb7cb2a0f6a36a2e3dfcaf0c71 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Mon, 24 Oct 2016 18:42:55 +0800 Subject: [PATCH 14/22] remove unnecessary test cases. --- .../spark/sql/DataFrameAggregateSuite.scala | 1 - ...stDISTs-0-d9065e533430691d70b3370174fbbd50 | 52 +++++++++--------- .../sql/catalyst/ExpressionToSQLSuite.scala | 5 +- .../sql/hive/execution/WindowQuerySuite.scala | 55 +++++++++---------- 4 files changed, 55 insertions(+), 58 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index c08d349df1b9e..7aa4f0026f275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.SparkException import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 b/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 index 6dbaaad364ea8..e7c39f454fb37 100644 --- a/sql/hive/src/test/resources/golden/windowing.q -- 21. 
testDISTs-0-d9065e533430691d70b3370174fbbd50 +++ b/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 @@ -1,26 +1,26 @@ -Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1753.76,"y":1.0}] [121152.0] 1 -Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] [115872.0] 2 -Manufacturer#1 almond antique chartreuse lavender yellow 34 [{"x":1173.15,"y":2.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] [110592.0] 3 -Manufacturer#1 almond antique salmon chartreuse burlywood 6 [{"x":1173.15,"y":1.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] [86428.0] 4 -Manufacturer#1 almond aquamarine burnished black steel 28 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] [86098.0] 5 -Manufacturer#1 almond aquamarine pink moccasin thistle 42 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0}] [86428.0] 6 -Manufacturer#2 almond antique violet chocolate turquoise 14 [{"x":1690.68,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] [146985.0] 1 -Manufacturer#2 almond antique violet turquoise frosted 40 [{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] [139825.5] 2 -Manufacturer#2 almond aquamarine midnight light salmon 2 [{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] [146985.0] 3 -Manufacturer#2 almond aquamarine rose maroon antique 25 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] [169347.0] 4 -Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":2031.98,"y":1.0}] [146985.0] 5 -Manufacturer#3 almond antique chartreuse khaki white 17 [{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0}] [90681.0] 1 -Manufacturer#3 almond antique forest lavender goldenrod 14 [{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] [65831.5] 2 -Manufacturer#3 almond antique metallic orange dim 19 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] [90681.0] 3 -Manufacturer#3 almond antique misty red olive 1 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] [76690.0] 4 -Manufacturer#3 almond antique olive coral navajo 45 [{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] [112398.0] 5 -Manufacturer#4 almond antique gainsboro frosted violet 10 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0}] [48427.0] 1 -Manufacturer#4 almond antique violet mint lemon 39 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] [46844.0] 2 -Manufacturer#4 almond aquamarine floral ivory bisque 27 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] [45261.0] 3 -Manufacturer#4 almond aquamarine yellow dodger mint 7 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1844.92,"y":1.0}] [39309.0] 4 -Manufacturer#4 almond azure aquamarine papaya violet 12 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1844.92,"y":1.0}] [33357.0] 5 -Manufacturer#5 almond antique blue firebrick mint 31 [{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] [155733.0] 1 -Manufacturer#5 almond antique medium 
spring khaki 6 [{"x":1018.1,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] [99201.0] 2 -Manufacturer#5 almond antique sky peru orange 2 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] [78486.0] 3 -Manufacturer#5 almond aquamarine dodger light gainsboro 46 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0}] [60577.5] 4 -Manufacturer#5 almond azure blanched chiffon midnight 23 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1788.73,"y":1.0}] [78486.0] 5 +Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1753.76,"y":1.0}] 121152.0 1 +Manufacturer#1 almond antique burnished rose metallic 2 [{"x":1173.15,"y":2.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] 115872.0 2 +Manufacturer#1 almond antique chartreuse lavender yellow 34 [{"x":1173.15,"y":2.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1753.76,"y":1.0}] 110592.0 3 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 [{"x":1173.15,"y":1.0},{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] 86428.0 4 +Manufacturer#1 almond aquamarine burnished black steel 28 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0},{"x":1753.76,"y":1.0}] 86098.0 5 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 [{"x":1414.42,"y":1.0},{"x":1602.59,"y":1.0},{"x":1632.66,"y":1.0}] 86428.0 6 +Manufacturer#2 almond antique violet chocolate turquoise 14 [{"x":1690.68,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 146985.0 1 +Manufacturer#2 almond antique violet turquoise frosted 40 [{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 139825.5 2 +Manufacturer#2 almond aquamarine midnight light salmon 2 [{"x":1690.68,"y":1.0},{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 146985.0 3 +Manufacturer#2 almond aquamarine rose maroon antique 25 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":1800.7,"y":1.0},{"x":2031.98,"y":1.0}] 169347.0 4 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 [{"x":1698.66,"y":1.0},{"x":1701.6,"y":1.0},{"x":2031.98,"y":1.0}] 146985.0 5 +Manufacturer#3 almond antique chartreuse khaki white 17 [{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0}] 90681.0 1 +Manufacturer#3 almond antique forest lavender goldenrod 14 [{"x":1190.27,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] 65831.5 2 +Manufacturer#3 almond antique metallic orange dim 19 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1671.68,"y":1.0},{"x":1922.98,"y":1.0}] 90681.0 3 +Manufacturer#3 almond antique misty red olive 1 [{"x":1190.27,"y":1.0},{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] 76690.0 4 +Manufacturer#3 almond antique olive coral navajo 45 [{"x":1337.29,"y":1.0},{"x":1410.39,"y":1.0},{"x":1922.98,"y":1.0}] 112398.0 5 +Manufacturer#4 almond antique gainsboro frosted violet 10 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0}] 48427.0 1 +Manufacturer#4 almond antique violet mint lemon 39 [{"x":1206.26,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] 46844.0 2 +Manufacturer#4 almond aquamarine floral ivory bisque 27 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1620.67,"y":1.0},{"x":1844.92,"y":1.0}] 45261.0 3 +Manufacturer#4 almond aquamarine yellow dodger mint 7 
[{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1375.42,"y":1.0},{"x":1844.92,"y":1.0}] 39309.0 4 +Manufacturer#4 almond azure aquamarine papaya violet 12 [{"x":1206.26,"y":1.0},{"x":1290.35,"y":1.0},{"x":1844.92,"y":1.0}] 33357.0 5 +Manufacturer#5 almond antique blue firebrick mint 31 [{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] 155733.0 1 +Manufacturer#5 almond antique medium spring khaki 6 [{"x":1018.1,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] 99201.0 2 +Manufacturer#5 almond antique sky peru orange 2 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0},{"x":1789.69,"y":1.0}] 78486.0 3 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1611.66,"y":1.0},{"x":1788.73,"y":1.0}] 60577.5 4 +Manufacturer#5 almond azure blanched chiffon midnight 23 [{"x":1018.1,"y":1.0},{"x":1464.48,"y":1.0},{"x":1788.73,"y":1.0}] 78486.0 5 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index 3a2b4efc887e4..27ea167b9050c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -41,8 +41,6 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { .saveAsTable("t1") spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") - - spark.range(10).select('id as 'key, 'id as 'value).write.saveAsTable("t3") } override protected def afterAll(): Unit = { @@ -175,7 +173,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key") - checkSqlGeneration("SELECT percentile(value, array(0.5d, 0.9d)) FROM t3 GROUP BY key") + checkSqlGeneration("SELECT percentile(value, 0.25) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT percentile(value, array(0.25, 0.75)) FROM t1 GROUP BY key") checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index 5ea85aabe209c..0ff3511c87a4f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -75,40 +75,39 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto |stddev(p_size) over(distribute by p_mfgr sort by p_name) as st, |first_value(p_size % 5) over(distribute by p_mfgr sort by p_name) as fv, |last_value(p_size) over(distribute by p_mfgr sort by p_name) as lv, - |first_value(p_size) over w1 as fvW1, - |percentile(p_size, 0.5) over w1 as per + |first_value(p_size) over w1 as fvW1 |from part |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name | rows between 2 preceding and 2 following) """.stripMargin), // scalastyle:off Seq( - Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2, 2.0), - 
Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2, 4.0), - Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2, 6.0), - Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2, 28.0), - Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34, 31.0), - Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6, 28.0), - Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14, 14.0), - Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14, 19.5), - Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14, 18.0), - Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40, 21.5), - Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2, 18.0), - Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17, 17.0), - Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17, 15.5), - Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17, 17.0), - Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14, 16.5), - Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19, 19.0), - Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10, 27.0), - Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10, 18.5), - Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10, 12.0), - Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39, 19.5), - Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27, 12.0), - Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31, 6.0), - Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31, 18.5), - Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31, 23.0), - Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6, 14.5), - Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2, 23.0))) + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", 
"almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17), + Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2))) // scalastyle:on } From 93f8285ee5b039aff951d179a4f0a12b51a365a6 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Tue, 25 Oct 2016 01:13:50 +0800 Subject: [PATCH 15/22] add expression level test cases. 
--- .../expressions/aggregate/Percentile.scala | 16 +- .../aggregate/PercentileSuite.scala | 191 ++++++++++++++++++ 2 files changed, 200 insertions(+), 7 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 94c5f61795234..136991ec495d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -75,7 +75,8 @@ case class Percentile( override def children: Seq[Expression] = child :: percentageExpression :: Nil - override def nullable: Boolean = false + // Returns null for empty inputs + override def nullable: Boolean = true override def dataType: DataType = if (returnPercentileArray) ArrayType(DoubleType) else DoubleType @@ -84,7 +85,7 @@ case class Percentile( Seq(NumericType, TypeCollection(NumericType, ArrayType)) override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function percentile") + TypeUtils.checkForNumericExpr(child.dataType, "function percentile") override def supportsPartial: Boolean = false @@ -102,16 +103,17 @@ case class Percentile( private def evalPercentages(expr: Expression): (Boolean, Seq[Number]) = { val (isArrayType, values) = (expr.dataType, expr.eval()) match { - case (_, n: Number) => (false, Seq(n)) - case (_, d: Decimal) => (false, Seq(d.toDouble.asInstanceOf[Number])) + case (_, n: Number) => (false, Array(n)) + case (_, d: Decimal) => (false, Array(d.toDouble.asInstanceOf[Number])) case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => - (true, arrayData.toArray[Number](baseType).toSeq) + val numericArray = arrayData.toObjectArray(baseType) + (true, numericArray.map { x => + baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]).asInstanceOf[Number] + }) case other => throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") } - require(values.size > 0, s"Percentage values should not be empty.") - require(values.forall(value => value.doubleValue() >= 0.0 && value.doubleValue() <= 1.0), s"Percentage values must be between 0.0 and 1.0, current values = ${values.mkString(", ")}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala new file mode 100644 index 0000000000000..b4a0827b00bd3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + +class PercentileSuite extends SparkFunSuite { + + test("high level interface, update, merge, eval...") { + val count = 10000 + val data = (1 to count) + val percentages = Array(0, 0.25, 0.5, 0.75, 1) + val expectedPercentiles = Array(1, 2500.75, 5000.5, 7500.25, 10000) + val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) + val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_))) + val agg = new Percentile(childExpression, percentageExpression) + + assert(agg.nullable) + val group = (0 until data.length) + // Don't use group buffer for now. + val groupBuffer = InternalRow.empty + group.foreach { index => + val input = InternalRow(data(index)) + agg.update(groupBuffer, input) + } + + // Don't support partial aggregations for now. + val mergeBuffer = InternalRow.empty + agg.eval(mergeBuffer) match { + case arrayData: ArrayData => + val percentiles = arrayData.toDoubleArray() + assert(percentiles.zip(expectedPercentiles) + .forall(pair => pair._1 == pair._2)) + } + } + + test("low level interface, update, merge, eval...") { + val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) + val inputAggregationBufferOffset = 1 + val mutableAggregationBufferOffset = 2 + val percentage = 0.5 + + // Test update. 
+ val agg = new Percentile(childExpression, Literal(percentage)) + .withNewInputAggBufferOffset(inputAggregationBufferOffset) + .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) + + val mutableAggBuffer = new GenericInternalRow( + new Array[Any](mutableAggregationBufferOffset + 1)) + agg.initialize(mutableAggBuffer) + val dataCount = 10 + (1 to dataCount).foreach { data => + agg.update(mutableAggBuffer, InternalRow(data)) + } + + // Test eval + val expectedPercentile = 5.5 + assert(agg.eval(mutableAggBuffer).asInstanceOf[Double] == expectedPercentile) + } + + test("call from sql query") { + // sql, single percentile + assertEqual( + s"percentile(`a`, 0.5D)", + new Percentile("a".attr, Literal(0.5)).sql: String) + + // sql, array of percentile + assertEqual( + s"percentile(`a`, array(0.25D, 0.5D, 0.75D))", + new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))).sql: String) + + // sql(isDistinct = false), single percentile + assertEqual( + s"percentile(`a`, 0.5D)", + new Percentile("a".attr, Literal(0.5)).sql(isDistinct = false)) + + // sql(isDistinct = false), array of percentile + assertEqual( + s"percentile(`a`, array(0.25D, 0.5D, 0.75D))", + new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))) + .sql(isDistinct = false)) + + // sql(isDistinct = true), single percentile + assertEqual( + s"percentile(DISTINCT `a`, 0.5D)", + new Percentile("a".attr, Literal(0.5)).sql(isDistinct = true)) + + // sql(isDistinct = true), array of percentile + assertEqual( + s"percentile(DISTINCT `a`, array(0.25D, 0.5D, 0.75D))", + new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))) + .sql(isDistinct = true)) + } + + test("fail analysis if childExpression is invalid") { + val validDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + NullType) + val percentage = Literal(0.5) + + validDataTypes.foreach { dataType => + val child = AttributeReference("a", dataType)() + val percentile = new Percentile(child, percentage) + assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess) + } + + val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType, CalendarIntervalType) + + invalidDataTypes.foreach { dataType => + val child = AttributeReference("a", dataType)() + val percentile = new Percentile(child, percentage) + assertEqual(percentile.checkInputDataTypes(), + TypeCheckFailure(s"function percentile requires numeric types, not $dataType")) + } + } + + test("fails analysis if percentage(s) are invalid") { + val child = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) + val groupBuffer = InternalRow.empty + val input = InternalRow(1) + + val validPercentages = Seq(Literal(0), Literal(0.5), Literal(1), + CreateArray(Seq(0, 0.5, 1).map(Literal(_)))) + + validPercentages.foreach { percentage => + val percentile1 = new Percentile(child, percentage) + // Make sure the inputs are not empty. + percentile1.update(groupBuffer, input) + // No exception should be thrown. 
+ assert(percentile1.eval() != null) + } + + val invalidPercentages = Seq(Literal(-0.5), Literal(1.5), Literal(2), + CreateArray(Seq(-0.5, 0, 2).map(Literal(_)))) + + invalidPercentages.foreach { percentage => + val percentile2 = new Percentile(child, percentage) + percentile2.update(groupBuffer, input) + intercept[IllegalArgumentException](percentile2.eval(), + s"Percentage values must be between 0.0 and 1.0") + } + + val nonLiteralPercentage = Literal("val") + + val percentile3 = new Percentile(child, nonLiteralPercentage) + percentile3.update(groupBuffer, input) + intercept[AnalysisException](percentile3.eval(), + s"Invalid data type ${nonLiteralPercentage.dataType} for parameter percentage") + } + + test("null handling") { + val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) + val agg = new Percentile(childExpression, Literal(0.5)) + val buffer = new GenericInternalRow(new Array[Any](1)) + agg.initialize(buffer) + // Empty aggregation buffer + assert(agg.eval(buffer) == null) + // Empty input row + agg.update(buffer, InternalRow(null)) + assert(agg.eval(buffer) == null) + + // Add some non-empty row + agg.update(buffer, InternalRow(0)) + assert(agg.eval(buffer) != null) + } + + private def assertEqual[T](left: T, right: T): Unit = { + assert(left == right) + } +} From 8a08576980f737ac72a31a3c103d4cf24b539303 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Tue, 25 Oct 2016 05:01:23 +0800 Subject: [PATCH 16/22] fix scala style fail. --- .../sql/catalyst/expressions/aggregate/PercentileSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index b4a0827b00bd3..38c588d42016e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -125,7 +125,8 @@ class PercentileSuite extends SparkFunSuite { assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess) } - val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType, CalendarIntervalType) + val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType, + CalendarIntervalType) invalidDataTypes.foreach { dataType => val child = AttributeReference("a", dataType)() From 7731066d1cc59137101f333f77194a4202522294 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Wed, 23 Nov 2016 23:23:41 +0800 Subject: [PATCH 17/22] rewrite Percentile using the TypedImperativeAggregate interface. 
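This commit moves Percentile off the partial-aggregation-disabled ImperativeAggregate path (supportsPartial = false, with the counts held in a private var on the expression itself) and onto TypedImperativeAggregate[Countings]: every group gets its own typed buffer, partitions accumulate into it, the buffer is serialized into the aggregation buffer row, and the final stage deserializes, merges, and evals. The sketch below illustrates that lifecycle in plain, self-contained Scala; TypedAgg and ToyCounts are stand-in names, an Int input stands in for InternalRow, and the byte layout is simplified, so this mirrors the contract rather than Spark's actual API.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import scala.collection.mutable

    // Stand-in for the relevant subset of the TypedImperativeAggregate[T] contract.
    trait TypedAgg[T] {
      def createAggregationBuffer(): T
      def update(buffer: T, input: Int): Unit   // Spark passes an InternalRow here
      def merge(buffer: T, other: T): Unit
      def eval(buffer: T): Any
      def serialize(buffer: T): Array[Byte]
      def deserialize(bytes: Array[Byte]): T
    }

    // Toy aggregate with the same buffer layout Percentile uses: value -> count.
    object ToyCounts extends TypedAgg[mutable.Map[Int, Long]] {
      def createAggregationBuffer(): mutable.Map[Int, Long] = mutable.Map.empty

      def update(buf: mutable.Map[Int, Long], input: Int): Unit =
        buf(input) = buf.getOrElse(input, 0L) + 1L

      def merge(buf: mutable.Map[Int, Long], other: mutable.Map[Int, Long]): Unit =
        other.foreach { case (k, c) => buf(k) = buf.getOrElse(k, 0L) + c }

      // Percentile's eval sorts the keys and interpolates; a row count will do here.
      def eval(buf: mutable.Map[Int, Long]): Any = buf.values.sum

      def serialize(buf: mutable.Map[Int, Long]): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val out = new DataOutputStream(bos)
        buf.foreach { case (k, c) => out.writeInt(k); out.writeLong(c) }
        out.flush()
        bos.toByteArray
      }

      def deserialize(bytes: Array[Byte]): mutable.Map[Int, Long] = {
        val in = new DataInputStream(new ByteArrayInputStream(bytes))
        val buf = createAggregationBuffer()
        // Each entry is one Int key (4 bytes) plus one Long count (8 bytes).
        while (in.available() >= 12) buf(in.readInt()) = in.readLong()
        buf
      }
    }

    // Lifecycle: update per partition, serialize, then deserialize + merge + eval.
    object LifecycleDemo extends App {
      val p1 = ToyCounts.createAggregationBuffer()
      (1 to 5).foreach(ToyCounts.update(p1, _))
      val p2 = ToyCounts.createAggregationBuffer()
      (6 to 10).foreach(ToyCounts.update(p2, _))

      val merged = ToyCounts.createAggregationBuffer()
      ToyCounts.merge(merged, ToyCounts.deserialize(ToyCounts.serialize(p1)))
      ToyCounts.merge(merged, ToyCounts.deserialize(ToyCounts.serialize(p2)))
      assert(ToyCounts.eval(merged) == 10L)
    }

In this commit the real buffer is serialized through Kryo; later commits in the series swap that for UnsafeProjection-encoded rows.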
--- .../expressions/aggregate/Percentile.scala | 177 +++++++++++------- .../aggregate/PercentileSuite.scala | 103 ++++++++-- 2 files changed, 191 insertions(+), 89 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 136991ec495d1..78a113c94b47d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -17,10 +17,15 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import java.nio.ByteBuffer + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -53,7 +58,7 @@ case class Percentile( child: Expression, percentageExpression: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends ImperativeAggregate { + inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Countings] { def this(child: Expression, percentageExpression: Expression) = { this(child, percentageExpression, 0, 0) @@ -61,14 +66,12 @@ case class Percentile( override def prettyName: String = "percentile" - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Percentile = copy(mutableAggBufferOffset = newMutableAggBufferOffset) - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Percentile = copy(inputAggBufferOffset = newInputAggBufferOffset) - private var counts = new OpenHashMap[Number, Long] - // Mark as lazy so that percentageExpression is not evaluated during tree transformation. private lazy val (returnPercentileArray: Boolean, percentages: Seq[Number]) = evalPercentages(percentageExpression) @@ -87,18 +90,9 @@ case class Percentile( override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function percentile") - override def supportsPartial: Boolean = false - - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - override val aggBufferAttributes: Seq[AttributeReference] = Nil - - override val inputAggBufferAttributes: Seq[AttributeReference] = Nil - - override def initialize(buffer: InternalRow): Unit = { - // The counts OpenHashMap will contain values of other groups if we don't initialize it here. - // Since OpenHashMap doesn't support deletions, we have to create a new instance. - counts = new OpenHashMap[Number, Long] + override def createAggregationBuffer(): Countings = { + // Initialize new Countings instance here. 
+ Countings() } private def evalPercentages(expr: Expression): (Boolean, Seq[Number]) = { @@ -120,39 +114,20 @@ case class Percentile( (isArrayType, values) } - override def update(buffer: InternalRow, input: InternalRow): Unit = { + override def update(buffer: Countings, input: InternalRow): Unit = { val key = child.eval(input).asInstanceOf[Number] - - // Null values are ignored when computing percentiles. - if (key != null) { - counts.changeValue(key, 1L, _ + 1L) - } + buffer.add(key) } - override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = { - sys.error("Percentile cannot be used in partial aggregations.") + override def merge(buffer: Countings, other: Countings): Unit = { + buffer.merge(other) } - override def eval(buffer: InternalRow): Any = { - if (counts.isEmpty) { - return generateOutput(Seq.empty) - } - - val sortedCounts = counts.toSeq.sortBy(_._1)(new Ordering[Number]() { - override def compare(a: Number, b: Number): Int = - scala.math.signum(a.doubleValue() - b.doubleValue()).toInt - }) - val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { - (k1: (Number, Long), k2: (Number, Long)) => (k2._1, k1._2 + k2._2) - }.tail - val maxPosition = aggreCounts.last._2 - 1 - - generateOutput(percentages.map { percentile => - getPercentile(aggreCounts, maxPosition * percentile.doubleValue()).doubleValue() - }) + override def eval(buffer: Countings): Any = { + generateOutput(buffer.getPercentiles(percentages)) } - private def generateOutput(results: Seq[Double]): Any = { + private def generateOutput(results: Seq[Number]): Any = { if (results.isEmpty) { null } else if (returnPercentileArray) { @@ -162,40 +137,104 @@ case class Percentile( } } + override def serialize(obj: Countings): Array[Byte] = { + Percentile.serializer.serialize(obj).array() + } + + override def deserialize(bytes: Array[Byte]): Countings = { + Percentile.serializer.deserialize[Countings](ByteBuffer.wrap(bytes)) + } +} + +object Percentile { + object Countings { + def apply(): Countings = Countings(new OpenHashMap[Number, Long]) + + def apply(counts: OpenHashMap[Number, Long]): Countings = new Countings(counts) + } + /** - * Get the percentile value. + * A class that stores the numbers and their counts, used to support [[Percentile]] function. */ - private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { - // We may need to do linear interpolation to get the exact percentile - val lower = position.floor - val higher = position.ceil - - // Linear search since this won't take much time from the total execution anyway - // lower has the range of [0 .. total-1] - // The first entry with accumulated count (lower+1) corresponds to the lower position. - var i = 0 - while (aggreCounts(i)._2 < lower + 1) { - i += 1 + class Countings(val counts: OpenHashMap[Number, Long]) extends Serializable { + /** + * Insert a key into countings map. + */ + def add(key: Number): Unit = { + // Null values are ignored in countings. + if (key != null) { + counts.changeValue(key, 1L, _ + 1L) + } } - val lowerKey = aggreCounts(i)._1 - if (higher == lower) { - // no interpolation needed because position does not have a fraction - return lowerKey + /** + * In place merges in another Countings. + */ + def merge(other: Countings): Unit = { + other.counts.foreach { pair => + counts.changeValue(pair._1, pair._2, _ + pair._2) + } } - if (aggreCounts(i)._2 < higher + 1) { - i += 1 + /** + * Get the percentile value for every percentile in `percentages`. 
+ */ + def getPercentiles(percentages: Seq[Number]): Seq[Number] = { + if (counts.isEmpty) { + return Seq.empty + } + + val sortedCounts = counts.toSeq.sortBy(_._1)(new Ordering[Number]() { + override def compare(a: Number, b: Number): Int = + scala.math.signum(a.doubleValue() - b.doubleValue()).toInt + }) + val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { + (k1: (Number, Long), k2: (Number, Long)) => (k2._1, k1._2 + k2._2) + }.tail + val maxPosition = aggreCounts.last._2 - 1 + + percentages.map { percentile => + getPercentile(aggreCounts, maxPosition * percentile.doubleValue()) + } } - val higherKey = aggreCounts(i)._1 - if (higherKey == lowerKey) { - // no interpolation needed because lower position and higher position has the same key - return lowerKey + /** + * Get the percentile value. + */ + private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { + // We may need to do linear interpolation to get the exact percentile + val lower = position.floor + val higher = position.ceil + + // Linear search since this won't take much time from the total execution anyway + // lower has the range of [0 .. total-1] + // The first entry with accumulated count (lower+1) corresponds to the lower position. + var i = 0 + while (aggreCounts(i)._2 < lower + 1) { + i += 1 + } + + val lowerKey = aggreCounts(i)._1 + if (higher == lower) { + // no interpolation needed because position does not have a fraction + return lowerKey + } + + if (aggreCounts(i)._2 < higher + 1) { + i += 1 + } + val higherKey = aggreCounts(i)._1 + + if (higherKey == lowerKey) { + // no interpolation needed because lower position and higher position has the same key + return lowerKey + } + + // Linear interpolation to get the exact percentile + return (higher - position) * lowerKey.doubleValue() + + (position - lower) * higherKey.doubleValue() } - - // Linear interpolation to get the exact percentile - return (higher - position) * lowerKey.doubleValue() + - (position - lower) * higherKey.doubleValue() } + + val serializer: SerializerInstance = new KryoSerializer(new SparkConf).newInstance() } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 38c588d42016e..3da6b819080f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -23,31 +23,78 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ class PercentileSuite extends SparkFunSuite { - test("high level interface, update, merge, eval...") { + private val random = new java.util.Random() + + private val data = (0 until 10000).map { _ => + random.nextInt(10000) + } + + test("serialize and de-serialize") { + val serializer = Percentile.serializer + + // Check empty serialize and de-serialize + val emptyBuffer = Countings() + assert(compareEquals(emptyBuffer, serializer.deserialize(serializer.serialize(emptyBuffer)))) + + val buffer = Countings() + data.foreach { value 
=> + buffer.add(value) + } + assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer)))) + + val agg = new Percentile(BoundReference(0, DoubleType, true), Literal(0.5)) + assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) + } + + test("class Countings, basic operations") { + val valueCount = 10000 + val percentages = Seq[Number](0, 0.25, 0.5, 0.75, 1) + val buffer = Countings() + (1 to valueCount).grouped(10).foreach { group => + val partialBuffer = Countings() + group.foreach(x => partialBuffer.add(x)) + buffer.merge(partialBuffer) + } + val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000) + val percentiles = buffer.getPercentiles(percentages) + assert(percentiles.zip(expectedPercentiles) + .forall(pair => pair._1 == pair._2)) + } + + test("class Percentile, high level interface, update, merge, eval...") { val count = 10000 val data = (1 to count) - val percentages = Array(0, 0.25, 0.5, 0.75, 1) - val expectedPercentiles = Array(1, 2500.75, 5000.5, 7500.25, 10000) + val percentages = Seq(0, 0.25, 0.5, 0.75, 1) + val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000) val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_))) val agg = new Percentile(childExpression, percentageExpression) assert(agg.nullable) - val group = (0 until data.length) - // Don't use group buffer for now. - val groupBuffer = InternalRow.empty - group.foreach { index => + val group1 = (0 until data.length / 2) + val group1Buffer = agg.createAggregationBuffer() + group1.foreach { index => val input = InternalRow(data(index)) - agg.update(groupBuffer, input) + agg.update(group1Buffer, input) } - // Don't support partial aggregations for now. - val mergeBuffer = InternalRow.empty + val group2 = (data.length / 2 until data.length) + val group2Buffer = agg.createAggregationBuffer() + group2.foreach { index => + val input = InternalRow(data(index)) + agg.update(group2Buffer, input) + } + + val mergeBuffer = agg.createAggregationBuffer() + agg.merge(mergeBuffer, group1Buffer) + agg.merge(mergeBuffer, group2Buffer) + agg.eval(mergeBuffer) match { case arrayData: ArrayData => val percentiles = arrayData.toDoubleArray() @@ -56,13 +103,13 @@ class PercentileSuite extends SparkFunSuite { } } - test("low level interface, update, merge, eval...") { + test("class Percentile, low level interface, update, merge, eval...") { val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) val inputAggregationBufferOffset = 1 val mutableAggregationBufferOffset = 2 val percentage = 0.5 - // Test update. 
+ // Phase one, partial mode aggregation val agg = new Percentile(childExpression, Literal(percentage)) .withNewInputAggBufferOffset(inputAggregationBufferOffset) .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) @@ -74,8 +121,16 @@ class PercentileSuite extends SparkFunSuite { (1 to dataCount).foreach { data => agg.update(mutableAggBuffer, InternalRow(data)) } + agg.serializeAggregateBufferInPlace(mutableAggBuffer) - // Test eval + // Serialize the aggregation buffer + val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset) + val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized)) + + // Phase 2: final mode aggregation + // Re-initialize the aggregation buffer + agg.initialize(mutableAggBuffer) + agg.merge(mutableAggBuffer, inputAggBuffer) val expectedPercentile = 5.5 assert(agg.eval(mutableAggBuffer).asInstanceOf[Double] == expectedPercentile) } @@ -138,7 +193,6 @@ class PercentileSuite extends SparkFunSuite { test("fails analysis if percentage(s) are invalid") { val child = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) - val groupBuffer = InternalRow.empty val input = InternalRow(1) val validPercentages = Seq(Literal(0), Literal(0.5), Literal(1), @@ -146,10 +200,11 @@ class PercentileSuite extends SparkFunSuite { validPercentages.foreach { percentage => val percentile1 = new Percentile(child, percentage) + val group1Buffer = percentile1.createAggregationBuffer() // Make sure the inputs are not empty. - percentile1.update(groupBuffer, input) + percentile1.update(group1Buffer, input) // No exception should be thrown. - assert(percentile1.eval() != null) + assert(percentile1.eval(group1Buffer) != null) } val invalidPercentages = Seq(Literal(-0.5), Literal(1.5), Literal(2), @@ -157,16 +212,18 @@ class PercentileSuite extends SparkFunSuite { invalidPercentages.foreach { percentage => val percentile2 = new Percentile(child, percentage) - percentile2.update(groupBuffer, input) - intercept[IllegalArgumentException](percentile2.eval(), + val group2Buffer = percentile2.createAggregationBuffer() + percentile2.update(group2Buffer, input) + intercept[IllegalArgumentException](percentile2.eval(group2Buffer), s"Percentage values must be between 0.0 and 1.0") } val nonLiteralPercentage = Literal("val") val percentile3 = new Percentile(child, nonLiteralPercentage) - percentile3.update(groupBuffer, input) - intercept[AnalysisException](percentile3.eval(), + val group3Buffer = percentile3.createAggregationBuffer() + percentile3.update(group3Buffer, input) + intercept[AnalysisException](percentile3.eval(group3Buffer), s"Invalid data type ${nonLiteralPercentage.dataType} for parameter percentage") } @@ -186,6 +243,12 @@ class PercentileSuite extends SparkFunSuite { assert(agg.eval(buffer) != null) } + private def compareEquals(left: Countings, right: Countings): Boolean = { + left.counts.size == right.counts.size && left.counts.forall { pair => + right.counts.apply(pair._1) == pair._2 + } + } + private def assertEqual[T](left: T, right: T): Unit = { assert(left == right) } From 4ace3bc5402866761c6f1b61600c4d4e71321598 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Thu, 24 Nov 2016 14:28:55 +0800 Subject: [PATCH 18/22] fix class cast exception for output. 
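Before this fix, getPercentiles returned Seq[Number], so for an integral child column the output array could hold boxed values that are not java.lang.Double even though dataType advertises ArrayType(DoubleType); a consumer unboxing such an element as a Double then throws ClassCastException. Converting with doubleValue() before the results leave getPercentiles pins the element type. A toy reproduction of the mismatch, in plain Scala outside the Spark code path (CastDemo is an illustrative name):

    object CastDemo extends App {
      val results: Seq[Number] = Seq(1, 2, 3)    // boxed as java.lang.Integer
      val element: Any = results.head
      try {
        // Mirrors reading an ArrayType(DoubleType) element back as a Double.
        println(element.asInstanceOf[Double])    // throws ClassCastException
      } catch {
        case e: ClassCastException => println(s"mismatch: $e")
      }
      // The fix applied here: convert before the values leave the aggregate.
      val fixed: Seq[Double] = results.map(_.doubleValue())
      assert(fixed.head == 1.0)
    }

GenericArrayData simply stores whatever boxed values it is handed, so the conversion has to happen before the output array is constructed.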
--- .../sql/catalyst/expressions/aggregate/Percentile.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 78a113c94b47d..a13fd4708be9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -127,7 +127,7 @@ case class Percentile( generateOutput(buffer.getPercentiles(percentages)) } - private def generateOutput(results: Seq[Number]): Any = { + private def generateOutput(results: Seq[Double]): Any = { if (results.isEmpty) { null } else if (returnPercentileArray) { @@ -179,7 +179,7 @@ object Percentile { /** * Get the percentile value for every percentile in `percentages`. */ - def getPercentiles(percentages: Seq[Number]): Seq[Number] = { + def getPercentiles(percentages: Seq[Number]): Seq[Double] = { if (counts.isEmpty) { return Seq.empty } @@ -194,7 +194,7 @@ object Percentile { val maxPosition = aggreCounts.last._2 - 1 percentages.map { percentile => - getPercentile(aggreCounts, maxPosition * percentile.doubleValue()) + getPercentile(aggreCounts, maxPosition * percentile.doubleValue()).doubleValue() } } From e01d0b2ca5c7a0f7d693d48cb2dbefeed569fec0 Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Fri, 25 Nov 2016 12:25:47 +0800 Subject: [PATCH 19/22] Implement serializer for Percentile. --- .../expressions/aggregate/Percentile.scala | 66 +++++++++++++++++-- .../aggregate/PercentileSuite.scala | 8 ++- 2 files changed, 64 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index a13fd4708be9e..c0f6ec40541c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import java.nio.ByteBuffer - -import org.apache.spark.SparkConf -import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -28,8 +24,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET import org.apache.spark.util.collection.OpenHashMap + /** * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at * the given percentage(s) with value range in [0.0, 1.0]. 
@@ -138,11 +136,11 @@ case class Percentile( } override def serialize(obj: Countings): Array[Byte] = { - Percentile.serializer.serialize(obj).array() + Percentile.serializer.serialize(obj, child.dataType) } override def deserialize(bytes: Array[Byte]): Countings = { - Percentile.serializer.deserialize[Countings](ByteBuffer.wrap(bytes)) + Percentile.serializer.deserialize(bytes, child.dataType) } } @@ -236,5 +234,59 @@ object Percentile { } } - val serializer: SerializerInstance = new KryoSerializer(new SparkConf).newInstance() + + /** + * Serializer for class [[Countings]] + * + * This class is thread safe. + */ + class CountingsSerializer { + + final def serialize(obj: Countings, dataType: DataType): Array[Byte] = { + val counts = obj.counts + + // Write the size of counts map. + val sizeProjection = UnsafeProjection.create(Array[DataType](IntegerType)) + val row = InternalRow.apply(counts.size) + var buffer = sizeProjection.apply(row).getBytes + + // Write the pairs of counts map. + val projection = UnsafeProjection.create(Array[DataType](dataType, LongType)) + counts.foreach { pair => + val row = InternalRow.apply(pair._1, pair._2) + val unsafeRow = projection.apply(row) + buffer ++= unsafeRow.getBytes + } + + buffer + } + + final def deserialize(bytes: Array[Byte], dataType: DataType): Countings = { + val counts = new OpenHashMap[Number, Long] + var offset = 0 + + // Read the size of counts map + val sizeRow = new UnsafeRow(1) + val rowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(1) + sizeRow.pointTo(bytes, rowSizeInBytes) + val size = sizeRow.get(0, IntegerType).asInstanceOf[Integer] + offset += rowSizeInBytes + + // Read the pairs of counts map + val row = new UnsafeRow(2) + val pairRowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(2) + var i = 0 + while (i < size) { + row.pointTo(bytes, offset + BYTE_ARRAY_OFFSET, pairRowSizeInBytes) + val key = row.get(0, dataType).asInstanceOf[Number] + val count = row.get(1, LongType).asInstanceOf[Long] + offset += pairRowSizeInBytes + counts.update(key, count) + i += 1 + } + Countings(counts) + } + } + + val serializer: CountingsSerializer = new CountingsSerializer } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 3da6b819080f5..be03c6e7646e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -40,15 +40,17 @@ class PercentileSuite extends SparkFunSuite { // Check empty serialize and de-serialize val emptyBuffer = Countings() - assert(compareEquals(emptyBuffer, serializer.deserialize(serializer.serialize(emptyBuffer)))) + assert(compareEquals(emptyBuffer, + serializer.deserialize(serializer.serialize(emptyBuffer, DoubleType), DoubleType))) val buffer = Countings() data.foreach { value => buffer.add(value) } - assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer)))) + assert(compareEquals(buffer, + serializer.deserialize(serializer.serialize(buffer, IntegerType), IntegerType))) - val agg = new Percentile(BoundReference(0, DoubleType, true), Literal(0.5)) + val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5)) assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) } From b0aabf9824b85f1d249b25870ccda9a3a79d9691 Mon Sep 17 
00:00:00 2001 From: jiangxingbo Date: Sat, 26 Nov 2016 21:06:45 +0800 Subject: [PATCH 20/22] code refactor. --- .../expressions/aggregate/Percentile.scala | 204 ++++++++++-------- .../aggregate/PercentileSuite.scala | 52 +++-- 2 files changed, 149 insertions(+), 107 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index c0f6ec40541c2..0f173e74979f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.util + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET import org.apache.spark.util.collection.OpenHashMap @@ -71,45 +74,62 @@ case class Percentile( copy(inputAggBufferOffset = newInputAggBufferOffset) // Mark as lazy so that percentageExpression is not evaluated during tree transformation. - private lazy val (returnPercentileArray: Boolean, percentages: Seq[Number]) = - evalPercentages(percentageExpression) + private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType] + + @transient + private lazy val percentages = evalPercentages(percentageExpression) override def children: Seq[Expression] = child :: percentageExpression :: Nil // Returns null for empty inputs override def nullable: Boolean = true - override def dataType: DataType = - if (returnPercentileArray) ArrayType(DoubleType) else DoubleType + override lazy val dataType: DataType = percentageExpression.dataType match { + case _: ArrayType => ArrayType(DoubleType, false) + case _ => DoubleType + } - override def inputTypes: Seq[AbstractDataType] = - Seq(NumericType, TypeCollection(NumericType, ArrayType)) + override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match { + case _: ArrayType => Seq(NumericType, ArrayType(DoubleType, false)) + case _ => Seq(NumericType, DoubleType) + } - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function percentile") + // Check the inputTypes are valid, and the percentageExpression satisfies: + // 1. percentageExpression must be foldable; + // 2. percentages(s) must be in the range [0.0, 1.0]. 
+ override def checkInputDataTypes(): TypeCheckResult = { + // Validate the inputTypes + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!percentageExpression.foldable) { + // percentageExpression must be foldable + TypeCheckFailure(s"The percentage(s) must be a constant literal, " + + s"but got ${percentageExpression}") + } else if (percentages.exists(percentage => percentage < 0.0 || percentage > 1.0)) { + // percentages(s) must be in the range [0.0, 1.0] + TypeCheckFailure(s"Percentage(s) must be between 0.0 and 1.0, " + + s"but got ${percentageExpression}") + } else { + TypeCheckSuccess + } + } override def createAggregationBuffer(): Countings = { // Initialize new Countings instance here. Countings() } - private def evalPercentages(expr: Expression): (Boolean, Seq[Number]) = { - val (isArrayType, values) = (expr.dataType, expr.eval()) match { - case (_, n: Number) => (false, Array(n)) - case (_, d: Decimal) => (false, Array(d.toDouble.asInstanceOf[Number])) - case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => - val numericArray = arrayData.toObjectArray(baseType) - (true, numericArray.map { x => - baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]).asInstanceOf[Number] - }) - case other => - throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") - } - - require(values.forall(value => value.doubleValue() >= 0.0 && value.doubleValue() <= 1.0), - s"Percentage values must be between 0.0 and 1.0, current values = ${values.mkString(", ")}") - - (isArrayType, values) + private def evalPercentages(expr: Expression): Seq[Double] = (expr.dataType, expr.eval()) match { + case (_, n: Number) => Array(n.doubleValue()) + case (_, d: Decimal) => Array(d.toDouble) + case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => + val numericArray = arrayData.toObjectArray(baseType) + numericArray.map { x => + baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]) + } + case other => + throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") } override def update(buffer: Countings, input: InternalRow): Unit = { @@ -145,6 +165,7 @@ case class Percentile( } object Percentile { + object Countings { def apply(): Countings = Countings(new OpenHashMap[Number, Long]) @@ -177,7 +198,7 @@ object Percentile { /** * Get the percentile value for every percentile in `percentages`. */ - def getPercentiles(percentages: Seq[Number]): Seq[Double] = { + def getPercentiles(percentages: Seq[Double]): Seq[Double] = { if (counts.isEmpty) { return Seq.empty } @@ -186,43 +207,41 @@ object Percentile { override def compare(a: Number, b: Number): Int = scala.math.signum(a.doubleValue() - b.doubleValue()).toInt }) - val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { - (k1: (Number, Long), k2: (Number, Long)) => (k2._1, k1._2 + k2._2) - }.tail + var sum = 0L + val aggreCounts = sortedCounts.map { case (key, count) => + sum += count + (key, sum) + } val maxPosition = aggreCounts.last._2 - 1 percentages.map { percentile => - getPercentile(aggreCounts, maxPosition * percentile.doubleValue()).doubleValue() + getPercentile(aggreCounts, maxPosition * percentile).doubleValue() } } /** * Get the percentile value. + * + * This function has been based upon similar function from HIVE + * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`. 
*/ private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { // We may need to do linear interpolation to get the exact percentile - val lower = position.floor - val higher = position.ceil - - // Linear search since this won't take much time from the total execution anyway - // lower has the range of [0 .. total-1] - // The first entry with accumulated count (lower+1) corresponds to the lower position. - var i = 0 - while (aggreCounts(i)._2 < lower + 1) { - i += 1 - } + val lower = position.floor.toLong + val higher = position.ceil.toLong + + // Use binary search to find the lower and the higher position. + val countsArray = aggreCounts.map(_._2).toArray[Long] + val lowerIndex = binarySearchCount(countsArray, 0, aggreCounts.size, lower + 1) + val higherIndex = binarySearchCount(countsArray, 0, aggreCounts.size, higher + 1) - val lowerKey = aggreCounts(i)._1 + val lowerKey = aggreCounts(lowerIndex)._1 if (higher == lower) { // no interpolation needed because position does not have a fraction return lowerKey } - if (aggreCounts(i)._2 < higher + 1) { - i += 1 - } - val higherKey = aggreCounts(i)._1 - + val higherKey = aggreCounts(higherIndex)._1 if (higherKey == lowerKey) { // no interpolation needed because lower position and higher position has the same key return lowerKey @@ -232,8 +251,18 @@ object Percentile { return (higher - position) * lowerKey.doubleValue() + (position - lower) * higherKey.doubleValue() } - } + /** + * use a binary search to find the index of the position closest to the current value. + */ + private def binarySearchCount( + countsArray: Array[Long], start: Int, end: Int, value: Long): Int = { + util.Arrays.binarySearch(countsArray, 0, end, value) match { + case ix if ix < 0 => -(ix + 1) + case ix => ix + } + } + } /** * Serializer for class [[Countings]] @@ -243,48 +272,53 @@ object Percentile { class CountingsSerializer { final def serialize(obj: Countings, dataType: DataType): Array[Byte] = { - val counts = obj.counts - - // Write the size of counts map. - val sizeProjection = UnsafeProjection.create(Array[DataType](IntegerType)) - val row = InternalRow.apply(counts.size) - var buffer = sizeProjection.apply(row).getBytes - - // Write the pairs of counts map. - val projection = UnsafeProjection.create(Array[DataType](dataType, LongType)) - counts.foreach { pair => - val row = InternalRow.apply(pair._1, pair._2) - val unsafeRow = projection.apply(row) - buffer ++= unsafeRow.getBytes + val buffer = new Array[Byte](4 << 10) // 4K + val bos = new ByteArrayOutputStream() + val out = new DataOutputStream(bos) + try { + val counts = obj.counts + val projection = UnsafeProjection.create(Array[DataType](dataType, LongType)) + // Write pairs in counts map to byte buffer. 
+ counts.foreach { case (key, count) => + val row = InternalRow.apply(key, count) + val unsafeRow = projection.apply(row) + out.writeInt(unsafeRow.getSizeInBytes) + unsafeRow.writeToStream(out, buffer) + } + out.writeInt(-1) + out.flush() + + bos.toByteArray + } finally { + out.close() + bos.close() } - - buffer } final def deserialize(bytes: Array[Byte], dataType: DataType): Countings = { - val counts = new OpenHashMap[Number, Long] - var offset = 0 - - // Read the size of counts map - val sizeRow = new UnsafeRow(1) - val rowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(1) - sizeRow.pointTo(bytes, rowSizeInBytes) - val size = sizeRow.get(0, IntegerType).asInstanceOf[Integer] - offset += rowSizeInBytes - - // Read the pairs of counts map - val row = new UnsafeRow(2) - val pairRowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(2) - var i = 0 - while (i < size) { - row.pointTo(bytes, offset + BYTE_ARRAY_OFFSET, pairRowSizeInBytes) - val key = row.get(0, dataType).asInstanceOf[Number] - val count = row.get(1, LongType).asInstanceOf[Long] - offset += pairRowSizeInBytes - counts.update(key, count) - i += 1 + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(bis) + try { + val counts = new OpenHashMap[Number, Long] + // Read unsafeRow size and content in bytes. + var sizeOfNextRow = ins.readInt() + while (sizeOfNextRow >= 0) { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(2) + row.pointTo(bs, sizeOfNextRow) + // Insert the pairs into counts map. + val key = row.get(0, dataType).asInstanceOf[Number] + val count = row.get(1, LongType).asInstanceOf[Long] + counts.update(key, count) + sizeOfNextRow = ins.readInt() + } + + Countings(counts) + } finally { + ins.close() + bis.close() } - Countings(counts) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index be03c6e7646e1..b1af10860c89e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -56,7 +56,7 @@ class PercentileSuite extends SparkFunSuite { test("class Countings, basic operations") { val valueCount = 10000 - val percentages = Seq[Number](0, 0.25, 0.5, 0.75, 1) + val percentages = Seq[Double](0, 0.25, 0.5, 0.75, 1) val buffer = Countings() (1 to valueCount).grouped(10).foreach { group => val partialBuffer = Countings() @@ -172,8 +172,7 @@ class PercentileSuite extends SparkFunSuite { } test("fail analysis if childExpression is invalid") { - val validDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - NullType) + val validDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) val percentage = Literal(0.5) validDataTypes.foreach { dataType => @@ -183,13 +182,14 @@ class PercentileSuite extends SparkFunSuite { } val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType, - CalendarIntervalType) + CalendarIntervalType, NullType) invalidDataTypes.foreach { dataType => val child = AttributeReference("a", dataType)() val percentile = new Percentile(child, percentage) assertEqual(percentile.checkInputDataTypes(), - TypeCheckFailure(s"function percentile requires numeric types, not $dataType")) + TypeCheckFailure(s"argument 1 requires numeric type, however, 
" + + s"'`a`' is of ${dataType.simpleString} type.")) } } @@ -197,36 +197,44 @@ class PercentileSuite extends SparkFunSuite { val child = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) val input = InternalRow(1) - val validPercentages = Seq(Literal(0), Literal(0.5), Literal(1), + val validPercentages = Seq(Literal(0D), Literal(0.5), Literal(1D), CreateArray(Seq(0, 0.5, 1).map(Literal(_)))) validPercentages.foreach { percentage => val percentile1 = new Percentile(child, percentage) - val group1Buffer = percentile1.createAggregationBuffer() - // Make sure the inputs are not empty. - percentile1.update(group1Buffer, input) - // No exception should be thrown. - assert(percentile1.eval(group1Buffer) != null) + assertEqual(percentile1.checkInputDataTypes(), TypeCheckSuccess) } - val invalidPercentages = Seq(Literal(-0.5), Literal(1.5), Literal(2), + val invalidPercentages = Seq(Literal(-0.5), Literal(1.5), Literal(2D), CreateArray(Seq(-0.5, 0, 2).map(Literal(_)))) invalidPercentages.foreach { percentage => val percentile2 = new Percentile(child, percentage) - val group2Buffer = percentile2.createAggregationBuffer() - percentile2.update(group2Buffer, input) - intercept[IllegalArgumentException](percentile2.eval(group2Buffer), - s"Percentage values must be between 0.0 and 1.0") + assertEqual(percentile2.checkInputDataTypes(), + TypeCheckFailure(s"Percentage(s) must be between 0.0 and 1.0, " + + s"but got ${percentage.simpleString}")) } - val nonLiteralPercentage = Literal("val") + val nonFoldablePercentage = Seq(NonFoldableLiteral(0.5), + CreateArray(Seq(0, 0.5, 1).map(NonFoldableLiteral(_)))) - val percentile3 = new Percentile(child, nonLiteralPercentage) - val group3Buffer = percentile3.createAggregationBuffer() - percentile3.update(group3Buffer, input) - intercept[AnalysisException](percentile3.eval(group3Buffer), - s"Invalid data type ${nonLiteralPercentage.dataType} for parameter percentage") + nonFoldablePercentage.foreach { percentage => + val percentile3 = new Percentile(child, percentage) + assertEqual(percentile3.checkInputDataTypes(), + TypeCheckFailure(s"The percentage(s) must be a constant literal, " + + s"but got ${percentage}")) + } + + val invalidDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, + BooleanType, StringType, DateType, TimestampType, CalendarIntervalType, NullType) + + invalidDataTypes.foreach { dataType => + val percentage = Literal(0.5, dataType) + val percentile4 = new Percentile(child, percentage) + assertEqual(percentile4.checkInputDataTypes(), + TypeCheckFailure(s"argument 2 requires double type, however, " + + s"'0.5' is of ${dataType.simpleString} type.")) + } } test("null handling") { From 5b8cd4d5ba5b2cec4e7dac45a9831303f52a84ba Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Sun, 27 Nov 2016 00:44:24 +0800 Subject: [PATCH 21/22] remove the class Countings and CountingsSerializer --- .../expressions/aggregate/Percentile.scala | 294 +++++++----------- .../aggregate/PercentileSuite.scala | 44 +-- 2 files changed, 127 insertions(+), 211 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 0f173e74979f6..2827df119ffd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -20,17 +20,14 @@ 
package org.apache.spark.sql.catalyst.expressions.aggregate import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap - /** * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at * the given percentage(s) with value range in [0.0, 1.0]. @@ -59,7 +56,7 @@ case class Percentile( child: Expression, percentageExpression: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Countings] { + inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] { def this(child: Expression, percentageExpression: Expression) = { this(child, percentageExpression, 0, 0) @@ -74,10 +71,14 @@ case class Percentile( copy(inputAggBufferOffset = newInputAggBufferOffset) // Mark as lazy so that percentageExpression is not evaluated during tree transformation. + @transient private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType] @transient - private lazy val percentages = evalPercentages(percentageExpression) + private lazy val percentages = percentageExpression.eval() match { + case p: Double => Seq(p) + case a: ArrayData => a.toDoubleArray().toSeq + } override def children: Seq[Expression] = child :: percentageExpression :: Nil @@ -104,45 +105,56 @@ case class Percentile( defaultCheck } else if (!percentageExpression.foldable) { // percentageExpression must be foldable - TypeCheckFailure(s"The percentage(s) must be a constant literal, " + - s"but got ${percentageExpression}") + TypeCheckFailure("The percentage(s) must be a constant literal, " + + s"but got $percentageExpression") } else if (percentages.exists(percentage => percentage < 0.0 || percentage > 1.0)) { // percentages(s) must be in the range [0.0, 1.0] - TypeCheckFailure(s"Percentage(s) must be between 0.0 and 1.0, " + - s"but got ${percentageExpression}") + TypeCheckFailure("Percentage(s) must be between 0.0 and 1.0, " + + s"but got $percentageExpression") } else { TypeCheckSuccess } } - override def createAggregationBuffer(): Countings = { - // Initialize new Countings instance here. - Countings() + override def createAggregationBuffer(): OpenHashMap[Number, Long] = { + // Initialize new counts map instance here. 
+ new OpenHashMap[Number, Long]() } - private def evalPercentages(expr: Expression): Seq[Double] = (expr.dataType, expr.eval()) match { - case (_, n: Number) => Array(n.doubleValue()) - case (_, d: Decimal) => Array(d.toDouble) - case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => - val numericArray = arrayData.toObjectArray(baseType) - numericArray.map { x => - baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]) - } - case other => - throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") + override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = { + val key = child.eval(input).asInstanceOf[Number] + + // Null values are ignored in counts map. + if (key != null) { + buffer.changeValue(key, 1L, _ + 1L) + } } - override def update(buffer: Countings, input: InternalRow): Unit = { - val key = child.eval(input).asInstanceOf[Number] - buffer.add(key) + override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = { + other.foreach { case (key, count) => + buffer.changeValue(key, count, _ + count) + } } - override def merge(buffer: Countings, other: Countings): Unit = { - buffer.merge(other) + override def eval(buffer: OpenHashMap[Number, Long]): Any = { + generateOutput(getPercentiles(buffer)) } - override def eval(buffer: Countings): Any = { - generateOutput(buffer.getPercentiles(percentages)) + private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = { + if (buffer.isEmpty) { + return Seq.empty + } + + val sortedCounts = buffer.toSeq.sortBy(_._1)( + child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]]) + val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { + case ((key1, count1), (key2, count2)) => (key2, count1 + count2) + }.tail + val maxPosition = aggreCounts.last._2 - 1 + + percentages.map { percentile => + getPercentile(aggreCounts, maxPosition * percentile).doubleValue() + } } private def generateOutput(results: Seq[Double]): Any = { @@ -155,172 +167,96 @@ case class Percentile( } } - override def serialize(obj: Countings): Array[Byte] = { - Percentile.serializer.serialize(obj, child.dataType) - } - - override def deserialize(bytes: Array[Byte]): Countings = { - Percentile.serializer.deserialize(bytes, child.dataType) - } -} - -object Percentile { - - object Countings { - def apply(): Countings = Countings(new OpenHashMap[Number, Long]) - - def apply(counts: OpenHashMap[Number, Long]): Countings = new Countings(counts) - } - /** - * A class that stores the numbers and their counts, used to support [[Percentile]] function. + * Get the percentile value. + * + * This function has been based upon similar function from HIVE + * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`. */ - class Countings(val counts: OpenHashMap[Number, Long]) extends Serializable { - /** - * Insert a key into countings map. - */ - def add(key: Number): Unit = { - // Null values are ignored in countings. - if (key != null) { - counts.changeValue(key, 1L, _ + 1L) - } + private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { + // We may need to do linear interpolation to get the exact percentile + val lower = position.floor.toLong + val higher = position.ceil.toLong + + // Use binary search to find the lower and the higher position. 
+ val countsArray = aggreCounts.map(_._2).toArray[Long] + val lowerIndex = binarySearchCount(countsArray, 0, aggreCounts.size, lower + 1) + val higherIndex = binarySearchCount(countsArray, 0, aggreCounts.size, higher + 1) + + val lowerKey = aggreCounts(lowerIndex)._1 + if (higher == lower) { + // no interpolation needed because position does not have a fraction + return lowerKey } - /** - * In place merges in another Countings. - */ - def merge(other: Countings): Unit = { - other.counts.foreach { pair => - counts.changeValue(pair._1, pair._2, _ + pair._2) - } + val higherKey = aggreCounts(higherIndex)._1 + if (higherKey == lowerKey) { + // no interpolation needed because lower position and higher position has the same key + return lowerKey } - /** - * Get the percentile value for every percentile in `percentages`. - */ - def getPercentiles(percentages: Seq[Double]): Seq[Double] = { - if (counts.isEmpty) { - return Seq.empty - } - - val sortedCounts = counts.toSeq.sortBy(_._1)(new Ordering[Number]() { - override def compare(a: Number, b: Number): Int = - scala.math.signum(a.doubleValue() - b.doubleValue()).toInt - }) - var sum = 0L - val aggreCounts = sortedCounts.map { case (key, count) => - sum += count - (key, sum) - } - val maxPosition = aggreCounts.last._2 - 1 + // Linear interpolation to get the exact percentile + return (higher - position) * lowerKey.doubleValue() + + (position - lower) * higherKey.doubleValue() + } - percentages.map { percentile => - getPercentile(aggreCounts, maxPosition * percentile).doubleValue() - } + /** + * use a binary search to find the index of the position closest to the current value. + */ + private def binarySearchCount( + countsArray: Array[Long], start: Int, end: Int, value: Long): Int = { + util.Arrays.binarySearch(countsArray, 0, end, value) match { + case ix if ix < 0 => -(ix + 1) + case ix => ix } + } - /** - * Get the percentile value. - * - * This function has been based upon similar function from HIVE - * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`. - */ - private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { - // We may need to do linear interpolation to get the exact percentile - val lower = position.floor.toLong - val higher = position.ceil.toLong - - // Use binary search to find the lower and the higher position. - val countsArray = aggreCounts.map(_._2).toArray[Long] - val lowerIndex = binarySearchCount(countsArray, 0, aggreCounts.size, lower + 1) - val higherIndex = binarySearchCount(countsArray, 0, aggreCounts.size, higher + 1) - - val lowerKey = aggreCounts(lowerIndex)._1 - if (higher == lower) { - // no interpolation needed because position does not have a fraction - return lowerKey - } - - val higherKey = aggreCounts(higherIndex)._1 - if (higherKey == lowerKey) { - // no interpolation needed because lower position and higher position has the same key - return lowerKey + override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = { + val buffer = new Array[Byte](4 << 10) // 4K + val bos = new ByteArrayOutputStream() + val out = new DataOutputStream(bos) + try { + val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType)) + // Write pairs in counts map to byte buffer. 
+ obj.foreach { case (key, count) => + val row = InternalRow.apply(key, count) + val unsafeRow = projection.apply(row) + out.writeInt(unsafeRow.getSizeInBytes) + unsafeRow.writeToStream(out, buffer) } + out.writeInt(-1) + out.flush() - // Linear interpolation to get the exact percentile - return (higher - position) * lowerKey.doubleValue() + - (position - lower) * higherKey.doubleValue() - } - - /** - * use a binary search to find the index of the position closest to the current value. - */ - private def binarySearchCount( - countsArray: Array[Long], start: Int, end: Int, value: Long): Int = { - util.Arrays.binarySearch(countsArray, 0, end, value) match { - case ix if ix < 0 => -(ix + 1) - case ix => ix - } + bos.toByteArray + } finally { + out.close() + bos.close() } } - /** - * Serializer for class [[Countings]] - * - * This class is thread safe. - */ - class CountingsSerializer { - - final def serialize(obj: Countings, dataType: DataType): Array[Byte] = { - val buffer = new Array[Byte](4 << 10) // 4K - val bos = new ByteArrayOutputStream() - val out = new DataOutputStream(bos) - try { - val counts = obj.counts - val projection = UnsafeProjection.create(Array[DataType](dataType, LongType)) - // Write pairs in counts map to byte buffer. - counts.foreach { case (key, count) => - val row = InternalRow.apply(key, count) - val unsafeRow = projection.apply(row) - out.writeInt(unsafeRow.getSizeInBytes) - unsafeRow.writeToStream(out, buffer) - } - out.writeInt(-1) - out.flush() - - bos.toByteArray - } finally { - out.close() - bos.close() + override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = { + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(bis) + try { + val counts = new OpenHashMap[Number, Long] + // Read unsafeRow size and content in bytes. + var sizeOfNextRow = ins.readInt() + while (sizeOfNextRow >= 0) { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(2) + row.pointTo(bs, sizeOfNextRow) + // Insert the pairs into counts map. + val key = row.get(0, child.dataType).asInstanceOf[Number] + val count = row.get(1, LongType).asInstanceOf[Long] + counts.update(key, count) + sizeOfNextRow = ins.readInt() } - } - final def deserialize(bytes: Array[Byte], dataType: DataType): Countings = { - val bis = new ByteArrayInputStream(bytes) - val ins = new DataInputStream(bis) - try { - val counts = new OpenHashMap[Number, Long] - // Read unsafeRow size and content in bytes. - var sizeOfNextRow = ins.readInt() - while (sizeOfNextRow >= 0) { - val bs = new Array[Byte](sizeOfNextRow) - ins.readFully(bs) - val row = new UnsafeRow(2) - row.pointTo(bs, sizeOfNextRow) - // Insert the pairs into counts map. 
- val key = row.get(0, dataType).asInstanceOf[Number] - val count = row.get(1, LongType).asInstanceOf[Long] - counts.update(key, count) - sizeOfNextRow = ins.readInt() - } - - Countings(counts) - } finally { - ins.close() - bis.close() - } + counts + } finally { + ins.close() + bis.close() } } - - val serializer: CountingsSerializer = new CountingsSerializer } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index b1af10860c89e..f060ecc18426a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashMap class PercentileSuite extends SparkFunSuite { @@ -36,37 +35,17 @@ class PercentileSuite extends SparkFunSuite { } test("serialize and de-serialize") { - val serializer = Percentile.serializer - - // Check empty serialize and de-serialize - val emptyBuffer = Countings() - assert(compareEquals(emptyBuffer, - serializer.deserialize(serializer.serialize(emptyBuffer, DoubleType), DoubleType))) - - val buffer = Countings() - data.foreach { value => - buffer.add(value) - } - assert(compareEquals(buffer, - serializer.deserialize(serializer.serialize(buffer, IntegerType), IntegerType))) - val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5)) + + // Check empty serialize and deserialize + val buffer = new OpenHashMap[Number, Long]() assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) - } - test("class Countings, basic operations") { - val valueCount = 10000 - val percentages = Seq[Double](0, 0.25, 0.5, 0.75, 1) - val buffer = Countings() - (1 to valueCount).grouped(10).foreach { group => - val partialBuffer = Countings() - group.foreach(x => partialBuffer.add(x)) - buffer.merge(partialBuffer) + // Check non-empty buffer serializa and deserialize. 
 
   test("class Percentile, high level interface, update, merge, eval...") {
@@ -253,9 +232,10 @@ class PercentileSuite extends SparkFunSuite {
     assert(agg.eval(buffer) != null)
   }
 
-  private def compareEquals(left: Countings, right: Countings): Boolean = {
-    left.counts.size == right.counts.size && left.counts.forall { pair =>
-      right.counts.apply(pair._1) == pair._2
+  private def compareEquals(
+      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+    left.size == right.size && left.forall { case (key, count) =>
+      right.apply(key) == count
     }
   }
 
From 3c699adfee609781c1e4ce2c08493308f5e7f511 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Mon, 28 Nov 2016 16:46:10 +0800
Subject: [PATCH 22/22] revert Percentile to accept percentages in ArrayType(NumericType).

---
 .../expressions/aggregate/Percentile.scala | 21 ++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 2827df119ffd1..356e088d1d665 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import java.util
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
@@ -75,9 +76,15 @@ case class Percentile(
   private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType]
 
   @transient
-  private lazy val percentages = percentageExpression.eval() match {
-    case p: Double => Seq(p)
-    case a: ArrayData => a.toDoubleArray().toSeq
+  private lazy val percentages =
+    (percentageExpression.dataType, percentageExpression.eval()) match {
+      case (_, num: Double) => Seq(num)
+      case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
+        val numericArray = arrayData.toObjectArray(baseType)
+        numericArray.map { x =>
+          baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
+        }.toSeq
+      case other =>
+        throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages")
   }
 
   override def children: Seq[Expression] = child :: percentageExpression :: Nil
@@ -91,7 +98,7 @@ case class Percentile(
   }
 
   override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
-    case _: ArrayType => Seq(NumericType, ArrayType(DoubleType, false))
+    case _: ArrayType => Seq(NumericType, ArrayType)
     case _ => Seq(NumericType, DoubleType)
   }
 
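The core of this revert: percentages may now arrive as an array of any numeric type, and each element is normalized to a Double through the element type's numeric instance. A simplified model using scala.math.Numeric in place of Catalyst's NumericType.numeric (a sketch under that substitution, not the patch's API):

    // Accept any numeric element type and normalize to Seq[Double].
    def toDoubles[T](percentages: Seq[T])(implicit num: Numeric[T]): Seq[Double] =
      percentages.map(num.toDouble)

    toDoubles(Seq(0, 1))          // integer percentages => Seq(0.0, 1.0)
    toDoubles(Seq(0.25f, 0.75f))  // float percentages   => Seq(0.25, 0.75)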
@@ -147,13 +154,13 @@ case class Percentile(
 
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
       child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
-    val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
+    val accumulatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
-    val maxPosition = aggreCounts.last._2 - 1
+    val maxPosition = accumulatedCounts.last._2 - 1
 
     percentages.map { percentile =>
-      getPercentile(aggreCounts, maxPosition * percentile).doubleValue()
+      getPercentile(accumulatedCounts, maxPosition * percentile).doubleValue()
     }
   }
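Taken together, eval sorts the distinct values, builds running totals with scanLeft, computes position = (N - 1) * p over the N input rows, and linearly interpolates between the two keys bracketing that position. Below is a compact, Spark-free model of that algorithm; it assumes a non-empty counts map and uses a linear scan where the patch's getPercentile uses binary search:

    def exactPercentile(counts: Map[Long, Long], p: Double): Double = {
      require(p >= 0.0 && p <= 1.0, "percentage must lie in [0, 1]")
      val sortedCounts = counts.toSeq.sortBy(_._1)
      // Running totals: Map(1 -> 2, 2 -> 1) becomes Seq((1, 2), (2, 3)).
      val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
        case ((_, count1), (key2, count2)) => (key2, count1 + count2)
      }.tail
      val position = (accumulatedCounts.last._2 - 1) * p
      val (lower, higher) = (position.floor.toLong, position.ceil.toLong)
      // The value at position i is the first key whose running total exceeds i.
      def keyAt(i: Long): Long = accumulatedCounts.find(_._2 > i).get._1
      val (lowerKey, higherKey) = (keyAt(lower), keyAt(higher))
      if (lower == higher || lowerKey == higherKey) lowerKey.toDouble
      else (higher - position) * lowerKey + (position - lower) * higherKey
    }

    // For ten rows holding 1..10 once each, the median interpolates to 5.5:
    // exactPercentile((1L to 10L).map(_ -> 1L).toMap, 0.5) == 5.5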