From bc8ab0cae172e6e6fa7dfa045d049568468cc30d Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 16 Sep 2015 16:52:43 -0700 Subject: [PATCH 01/22] Added skewness and kurtosis aggregate functions --- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../catalyst/analysis/HiveTypeCoercion.scala | 2 + .../spark/sql/catalyst/dsl/package.scala | 2 + .../expressions/aggregate/functions.scala | 208 +++++++++++++++++- .../expressions/aggregate/utils.scala | 12 + .../sql/catalyst/expressions/aggregates.scala | 57 +++++ .../sql/catalyst/expressions/arithmetic.scala | 1 + .../org/apache/spark/sql/GroupedData.scala | 26 +++ .../org/apache/spark/sql/functions.scala | 36 +++ .../spark/sql/DataFrameAggregateSuite.scala | 36 +++ 10 files changed, 381 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3dce6c1a27e85..d513bf8cdfd63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -189,6 +189,8 @@ object FunctionRegistry { expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), + expression[Skewness]("skewness"), + expression[Kurtosis]("kurtosis"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 1140150f66864..a8300069ede55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -300,6 +300,8 @@ object HiveTypeCoercion { case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) + case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 27b3cd84b3846..e7224ae243ece 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -162,6 +162,8 @@ package object dsl { def stddev(e: Expression): Expression = Stddev(e) def stddev_pop(e: Expression): Expression = StddevPop(e) def stddev_samp(e: Expression): Expression = StddevSamp(e) + def skewness(e: Expression): Expression = Skewness(e) + def kurtosis(e: Expression): Expression = Kurtosis(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 515246d344244..b35b7e5ef4982 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -90,6 +90,137 @@ case class Average(child: Expression) extends DeclarativeAggregate { } } +abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + protected val resultType = DoubleType + + protected val currentM4 = AttributeReference("currentM4", resultType)() + protected val currentM3 = AttributeReference("currentM3", resultType)() + protected val currentM2 = AttributeReference("currentM2", resultType)() + protected val currentM1 = AttributeReference("currentM1", resultType)() + protected val currentM0 = AttributeReference("currentM2", resultType)() + + override val bufferAttributes = currentM4 :: currentM3 :: currentM2 :: currentM1 :: currentM0 :: Nil + + override val initialValues = Seq( + /* currentM4 = */ Cast(Literal(0), resultType), + /* currentM3 = */ Cast(Literal(0), resultType), + /* currentM2 = */ Cast(Literal(0), resultType), + /* currentM1 = */ Cast(Literal(0), resultType), + /* currentM0 = */ Cast(Literal(0), resultType) + ) + + override val updateExpressions = { + val x = Coalesce( + Cast(child, resultType) :: Cast(Literal(0), resultType) :: Nil + ) + lazy val updateM0: Expression = { + If(IsNull(child), currentM0, currentM0 + Cast(Literal(1), resultType)) + } + lazy val delta = Subtract(x, currentM1) + lazy val deltaN = Divide(delta, updateM0) + + lazy val updateM2: Expression = { + Add(currentM2, Multiply(deltaN * delta, Subtract(updateM0, Cast(Literal(1), resultType)))) + } + lazy val updateM1: Expression = { + Add(currentM1, Divide(delta, updateM0)) + } + lazy val updateM3: Expression = { + currentM3 + deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * + (updateM0 - Cast(Literal(2), resultType)) - deltaN * currentM2 * Cast(Literal(3), resultType) + } + lazy val updateM4: Expression = { + currentM4 + deltaN * deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * + (updateM0 * updateM0 - updateM0 * Cast(Literal(3), resultType) + Cast(Literal(3), resultType)) + + deltaN * deltaN * currentM2 * Cast(Literal(6), resultType) - deltaN * currentM3 * Cast(Literal(4), resultType) + } + + Seq( + /* currentM4 = */ If(IsNull(child), currentM4, updateM4), + /* currentM3 = */ If(IsNull(child), currentM3, updateM3), + /* currentM2 = */ If(IsNull(child), currentM2, updateM2), + /* currentMean = */ If(IsNull(child), currentM1, updateM1), + /* currentCount = */ If(IsNull(child), currentM0,updateM0) + ) + } + + override val mergeExpressions = { + + + lazy val M0 = currentM0.left + currentM0.right + lazy val delta = currentM1.right - currentM1.left + lazy val deltaN = delta / M0 + + lazy val M1 = currentM1.left + delta * (currentM0.right / M0) + + lazy val M2 = currentM2.left + currentM2.right + delta * deltaN * currentM0.left * currentM0.right + + lazy val M3 = currentM3.left + currentM3.right + delta * delta * delta * currentM0.left * + currentM0.right * (currentM0.left - currentM0.right) / (currentM0.right * currentM0.right) + + deltaN * (currentM0.left * currentM2.right - currentM0.right * currentM2.left) * + Cast(Literal(3), resultType) + + lazy val M4 = currentM4.left + currentM4.right + deltaN * deltaN * deltaN * delta * currentM0.left * + currentM0.right * (currentM0.left * currentM0.left - currentM0.left * currentM0.right + + currentM0.right * currentM0.right) + deltaN * deltaN * Cast(Literal(6), resultType) * + (currentM0.left * currentM0.left * currentM2.right + + currentM0.right * currentM0.right * currentM2.left) + + deltaN * Cast(Literal(4), resultType) * (currentM0.left * currentM3.right - currentM0.right * currentM3.left) + + Seq( + /* currentM4 = */ If(IsNull(currentM4.left), currentM4.right, + If(IsNull(currentM4.right), currentM4.left, M4)), + /* currentM3 = */ If(IsNull(currentM3.left), currentM3.right, + If(IsNull(currentM3.right), currentM3.left, M3)), + /* currentM2 = */ If(IsNull(currentM2.left), currentM2.right, + If(IsNull(currentM2.right), currentM2.left, M2)), + /* currentMean = */ If(IsNull(currentM1.left), currentM1.right, + If(IsNull(currentM1.right), currentM1.left, M1)), + /* currentCount = */ If(IsNull(currentM0.left), currentM0.right, + If(IsNull(currentM0.right), currentM0.left, M0)) + ) + } +} + +case class Skewness(child: Expression) extends StatisticalMoments(child) { + override def prettyName = "skewness" + // TODO: protect against neg sqrt + override val evaluateExpression: Expression = { + Cast(currentM1, resultType) + } +// override val evaluateExpression: Expression = { +// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), +// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), +// Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType))) +// } +} + +case class Kurtosis(child: Expression) extends StatisticalMoments(child) { + override def prettyName = "kurtosis" + override val evaluateExpression = { +// Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType) + If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType))) + } +} + case class Count(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil @@ -101,7 +232,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val currentCount = AttributeReference("currentCount", LongType)() + private val currentCount = AttributeReference("currentCount", LongType)() override val aggBufferAttributes = currentCount :: Nil @@ -930,3 +1061,78 @@ object HyperLogLogPlusPlus { ) // scalastyle:on } + +case class KahanAverage(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = DoubleType + + private val currentMean = AttributeReference("currentMean", resultType)() + private val currentMeanErr = AttributeReference("currentMeanErr", resultType)() + private val currentCount = AttributeReference("currentCount", resultType)() + + override val bufferAttributes = currentMean :: currentMeanErr :: currentCount :: Nil + + override val initialValues = Seq( + /* currentMean = */ Cast(Literal(0), resultType), + /* currentMeanErr = */ Cast(Literal(0), resultType), + /* currentCount = */ Cast(Literal(0), resultType) + ) + def kahan(s1: Expression, c1: Expression, s2: Expression, c2: Expression) = { + val correctedS2 = s2 + (c1 + c2) + val s = s1 + correctedS2 + val diff = (s - s1) + val correction = correctedS2 - diff + (s, correction) + } + + override val updateExpressions = { + val x = Coalesce( + Cast(child, resultType) :: Cast(Literal(0), resultType) :: Nil + ) + lazy val updateCount = { + If(IsNull(child), currentCount, currentCount + Cast(Literal(1), resultType)) + } + lazy val delta = Subtract(x, currentMean) + lazy val (updateMean, updateMeanErr) = kahan(currentMean, currentMeanErr, delta / updateCount, Cast(Literal(0), resultType)) + Seq( + /* currentMean = */ If(IsNull(child), currentMean, updateMean), + If(IsNull(child), currentMeanErr, updateMeanErr), + /* currentCount = */ If(IsNull(child), currentCount, updateCount) + ) + } + + + override val mergeExpressions = { + lazy val updateCount = currentCount.left + currentCount.right + lazy val delta = currentMean.right - currentMean.left + + lazy val (updateMean, updateMeanErr) = kahan(currentMean.left, currentMeanErr.left, delta * (currentCount.right / updateCount), Cast(Literal(0), resultType)) + + Seq( + /* currentMean = */ If(IsNull(currentMean.left), currentMean.right, + If(IsNull(currentMean.right), currentMean.left, updateMean)), + If(IsNull(currentMeanErr.left), currentMeanErr.right, + If(IsNull(currentMeanErr.right), currentMeanErr.left, updateMeanErr)), + /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, + If(IsNull(currentCount.right), currentCount.left, updateCount)) + ) + } + + // If all input are nulls, currentCount will be 0 and we will get null after the division. + override val evaluateExpression = currentMean * Cast(Literal(1), resultType) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 12bdab0915801..1be66b0747a60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -48,6 +48,18 @@ object Utils { mode = aggregate.Complete, isDistinct = false) + case expressions.Skewness(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Skewness(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Kurtosis(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Kurtosis(child), + mode = aggregate.Complete, + isDistinct = false) + case expressions.Count(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Count(child), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 95061c4635879..75b651cdcfd6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -428,6 +428,63 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg TypeUtils.checkForNumericExpr(child.dataType, "function average") } +abstract class StatisticalMoments1(child: Expression) extends UnaryExpression with PartialAggregate1 { + + override def prettyName: String = "kurtosis" + + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + // Add 4 digits after decimal point, like Hive + DecimalType.bounded(precision + 4, scale + 4) + case _ => + DoubleType + } + + override def asPartial: SplitEvaluation = { + child.dataType match { + case DecimalType.Fixed(precision, scale) => + val partialSum = Alias(Sum(child), "PartialSum")() + val partialCount = Alias(Count(child), "PartialCount")() + + // partialSum already increase the precision by 10 + val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) + val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) + SplitEvaluation( + Cast(Divide(castedSum, castedCount), dataType), + partialCount :: partialSum :: Nil) + + case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + val partialCount = Alias(Count(child), "PartialCount")() + + val castedSum = Cast(Sum(partialSum.toAttribute), dataType) + val castedCount = Cast(Sum(partialCount.toAttribute), dataType) + SplitEvaluation( + Divide(castedSum, castedCount), + partialCount :: partialSum :: Nil) + } + } + + override def newInstance(): AverageFunction = new AverageFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") +} + +case class Kurtosis(child: Expression) extends StatisticalMoments1(child) { + + override def toString: String = s"KURTOSIS($child)" +} + +case class Skewness(child: Expression) extends StatisticalMoments1(child) { + + override def prettyName: String = "skewness" + + override def toString: String = s"SKEWNESS($child)" +} + case class AverageFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 61a17fd7db0fe..9031da6fcf1af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -93,6 +93,7 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } + abstract class BinaryArithmetic extends BinaryOperator { override def dataType: DataType = left.dataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 102b802ad0a0a..9b242caac994e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -128,6 +128,8 @@ class GroupedData protected[sql]( case "stddev_pop" => StddevPop case "stddev_samp" => StddevSamp case "sum" => Sum + case "skewness" => Skewness + case "kurtosis" => Kurtosis case "count" | "size" => // Turn count(*) into count(1) (inputExpr: Expression) => inputExpr match { @@ -250,6 +252,30 @@ class GroupedData protected[sql]( aggregateNumericColumns(colNames : _*)(Average) } + /** + * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the average values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def skewness(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Skewness) + } + + /** + * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the average values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def kurtosis(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Kurtosis) + } + /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 15c864a8ab641..23fff2d07f80e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -278,6 +278,42 @@ object functions { */ def mean(columnName: String): Column = avg(columnName) + /** + * Aggregate function: returns the average of the values in a group. + * Alias for avg. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(e: Column): Column = Skewness(e.expr) + + /** + * Aggregate function: returns the average of the values in a group. + * Alias for avg. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(columnName: String): Column = skewness(Column(columnName)) + + /** + * Aggregate function: returns the average of the values in a group. + * Alias for avg. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(e: Column): Column = Kurtosis(e.expr) + + /** + * Aggregate function: returns the average of the values in a group. + * Alias for avg. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) + /** * Aggregate function: returns the minimum value of the expression in a group. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f5ef9ffd7f4f2..07b868639b24f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -221,4 +221,40 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { emptyTableData.agg(sumDistinct('a)), Row(null)) } + + test("moments") { + checkAnswer( + testData2.agg(skewness('a)), + Row(0.0)) + + checkAnswer( + testData2.agg(kurtosis('a)), + Row(-1.5)) + } + + test("zero moments") { + val emptyTableData = Seq((1,2)).toDF("a", "b") + assert(emptyTableData.count() === 1) + + checkAnswer( + emptyTableData.agg(skewness('a)), + Row(0.0)) + + checkAnswer( + emptyTableData.agg(kurtosis('a)), + Row(0.0)) + } + + test("null moments") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + assert(emptyTableData.count() === 0) + + checkAnswer( + emptyTableData.agg(skewness('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(kurtosis('a)), + Row(null)) + } } From cf52ed7ce32e9b703282a91f04172df4f6b2e635 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 18 Sep 2015 13:05:10 -0700 Subject: [PATCH 02/22] Adding kahan updates to higher order aggregate stats --- .../expressions/aggregate/functions.scala | 165 +++++++++++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 44 ++++- 2 files changed, 160 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index b35b7e5ef4982..9f5df29a3bf26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -113,7 +113,7 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate protected val currentM3 = AttributeReference("currentM3", resultType)() protected val currentM2 = AttributeReference("currentM2", resultType)() protected val currentM1 = AttributeReference("currentM1", resultType)() - protected val currentM0 = AttributeReference("currentM2", resultType)() + protected val currentM0 = AttributeReference("currentM0", resultType)() override val bufferAttributes = currentM4 :: currentM3 :: currentM2 :: currentM1 :: currentM0 :: Nil @@ -155,8 +155,8 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate /* currentM4 = */ If(IsNull(child), currentM4, updateM4), /* currentM3 = */ If(IsNull(child), currentM3, updateM3), /* currentM2 = */ If(IsNull(child), currentM2, updateM2), - /* currentMean = */ If(IsNull(child), currentM1, updateM1), - /* currentCount = */ If(IsNull(child), currentM0,updateM0) + /* currentM1 = */ If(IsNull(child), currentM1, updateM1), + /* currentM0 = */ If(IsNull(child), currentM0,updateM0) ) } @@ -171,8 +171,8 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate lazy val M2 = currentM2.left + currentM2.right + delta * deltaN * currentM0.left * currentM0.right - lazy val M3 = currentM3.left + currentM3.right + delta * delta * delta * currentM0.left * - currentM0.right * (currentM0.left - currentM0.right) / (currentM0.right * currentM0.right) + + lazy val M3 = currentM3.left + currentM3.right + deltaN * deltaN * delta * currentM0.left * + currentM0.right * (currentM0.left - currentM0.right) + deltaN * (currentM0.left * currentM2.right - currentM0.right * currentM2.left) * Cast(Literal(3), resultType) @@ -190,31 +190,27 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate If(IsNull(currentM3.right), currentM3.left, M3)), /* currentM2 = */ If(IsNull(currentM2.left), currentM2.right, If(IsNull(currentM2.right), currentM2.left, M2)), - /* currentMean = */ If(IsNull(currentM1.left), currentM1.right, + /* currentM1 = */ If(IsNull(currentM1.left), currentM1.right, If(IsNull(currentM1.right), currentM1.left, M1)), - /* currentCount = */ If(IsNull(currentM0.left), currentM0.right, + /* currentM0 = */ If(IsNull(currentM0.left), currentM0.right, If(IsNull(currentM0.right), currentM0.left, M0)) ) } } case class Skewness(child: Expression) extends StatisticalMoments(child) { - override def prettyName = "skewness" + override def prettyName = "min" // TODO: protect against neg sqrt override val evaluateExpression: Expression = { - Cast(currentM1, resultType) + If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType))) } -// override val evaluateExpression: Expression = { -// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), -// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), -// Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType))) -// } } case class Kurtosis(child: Expression) extends StatisticalMoments(child) { override def prettyName = "kurtosis" override val evaluateExpression = { -// Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType) If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType))) @@ -1062,7 +1058,8 @@ object HyperLogLogPlusPlus { // scalastyle:on } -case class KahanAverage(child: Expression) extends AlgebraicAggregate { + +abstract class StableMoments(child: Expression) extends AlgebraicAggregate { override def children: Seq[Expression] = child :: Nil @@ -1079,19 +1076,33 @@ case class KahanAverage(child: Expression) extends AlgebraicAggregate { // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - private val resultType = DoubleType + protected val resultType = DoubleType - private val currentMean = AttributeReference("currentMean", resultType)() - private val currentMeanErr = AttributeReference("currentMeanErr", resultType)() - private val currentCount = AttributeReference("currentCount", resultType)() + protected val currentM4 = AttributeReference("currentM4", resultType)() + protected val currentM3 = AttributeReference("currentM3", resultType)() + protected val currentM2 = AttributeReference("currentM2", resultType)() + protected val currentM1 = AttributeReference("currentM1", resultType)() + protected val currentM4Err = AttributeReference("currentM4Err", resultType)() + protected val currentM3Err = AttributeReference("currentM3Err", resultType)() + protected val currentM2Err = AttributeReference("currentM2Err", resultType)() + protected val currentM1Err = AttributeReference("currentM1Err", resultType)() + protected val currentM0 = AttributeReference("currentM0", resultType)() - override val bufferAttributes = currentMean :: currentMeanErr :: currentCount :: Nil + override val bufferAttributes = currentM4 :: currentM3 :: currentM2 :: currentM1 :: + currentM4Err :: currentM3Err :: currentM2Err :: currentM1Err :: currentM0 :: Nil override val initialValues = Seq( - /* currentMean = */ Cast(Literal(0), resultType), - /* currentMeanErr = */ Cast(Literal(0), resultType), - /* currentCount = */ Cast(Literal(0), resultType) + /* currentM4 = */ Cast(Literal(0), resultType), + /* currentM3 = */ Cast(Literal(0), resultType), + /* currentM2 = */ Cast(Literal(0), resultType), + /* currentM1 = */ Cast(Literal(0), resultType), + /* currentM4Err = */ Cast(Literal(0), resultType), + /* currentM3Err = */ Cast(Literal(0), resultType), + /* currentM2Err = */ Cast(Literal(0), resultType), + /* currentM1Err = */ Cast(Literal(0), resultType), + /* currentM0 = */ Cast(Literal(0), resultType) ) + def kahan(s1: Expression, c1: Expression, s2: Expression, c2: Expression) = { val correctedS2 = s2 + (c1 + c2) val s = s1 + correctedS2 @@ -1104,35 +1115,103 @@ case class KahanAverage(child: Expression) extends AlgebraicAggregate { val x = Coalesce( Cast(child, resultType) :: Cast(Literal(0), resultType) :: Nil ) - lazy val updateCount = { - If(IsNull(child), currentCount, currentCount + Cast(Literal(1), resultType)) + lazy val updateM0: Expression = { + If(IsNull(child), currentM0, currentM0 + Cast(Literal(1), resultType)) } - lazy val delta = Subtract(x, currentMean) - lazy val (updateMean, updateMeanErr) = kahan(currentMean, currentMeanErr, delta / updateCount, Cast(Literal(0), resultType)) + lazy val delta = Subtract(x, currentM1) + lazy val deltaN = Divide(delta, updateM0) + + lazy val (updateM2, updateM2Err) = kahan(currentM2, currentM2Err, deltaN * delta * (updateM0 - Cast(Literal(1), resultType)), Cast(Literal(0), resultType)) + lazy val (updateM1, updateM1Err) = kahan(currentM1, currentM1Err, deltaN, Cast(Literal(0), resultType)) + + lazy val (updateM3, updateM3Err) = kahan(currentM3, currentM3Err, deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * + (updateM0 - Cast(Literal(2), resultType)) - deltaN * currentM2 * Cast(Literal(3), resultType), Cast(Literal(0), resultType)) + + lazy val (updateM4, updateM4Err) = kahan(currentM4, currentM4Err, deltaN * deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * + (updateM0 * updateM0 - updateM0 * Cast(Literal(3), resultType) + Cast(Literal(3), resultType)) + + deltaN * deltaN * currentM2 * Cast(Literal(6), resultType) - deltaN * currentM3 * Cast(Literal(4), resultType), + Cast(Literal(0), resultType)) + Seq( - /* currentMean = */ If(IsNull(child), currentMean, updateMean), - If(IsNull(child), currentMeanErr, updateMeanErr), - /* currentCount = */ If(IsNull(child), currentCount, updateCount) + /* currentM4 = */ If(IsNull(child), currentM4, updateM4), + /* currentM3 = */ If(IsNull(child), currentM3, updateM3), + /* currentM2 = */ If(IsNull(child), currentM2, updateM2), + /* currentM1 = */ If(IsNull(child), currentM1, updateM1), + /* currentM4Err = */ If(IsNull(child), currentM4Err, updateM4Err), + /* currentM3Err = */ If(IsNull(child), currentM3Err, updateM3Err), + /* currentM2Err = */ If(IsNull(child), currentM2Err, updateM2Err), + /* currentM1Err = */ If(IsNull(child), currentM1Err, updateM1Err), + /* currentM0 = */ If(IsNull(child), currentM0, updateM0) ) } - override val mergeExpressions = { - lazy val updateCount = currentCount.left + currentCount.right - lazy val delta = currentMean.right - currentMean.left - lazy val (updateMean, updateMeanErr) = kahan(currentMean.left, currentMeanErr.left, delta * (currentCount.right / updateCount), Cast(Literal(0), resultType)) + lazy val M0 = currentM0.left + currentM0.right + lazy val delta = currentM1.right - currentM1.left + lazy val deltaN = delta / M0 + + lazy val (updateM1, updateM1Err) = kahan(currentM1.left, currentM1Err.left, delta * (currentM0.right / M0), Cast(Literal(0), resultType)) + + lazy val (tmpM2, tmpM2Err) = kahan(currentM2.left, currentM2Err.left, currentM2.right, currentM2Err.right) + lazy val (updateM2, updateM2Err) = kahan(tmpM2, tmpM2Err, delta * deltaN * currentM0.left * currentM0.right, Cast(Literal(0), resultType)) + + lazy val (tmpM3, tmpM3Err) = kahan(currentM3.left, currentM3Err.left, currentM3.right, currentM3Err.right) + lazy val (updateM3, updateM3Err) = kahan(tmpM3, tmpM3Err, deltaN * deltaN * delta * currentM0.left * + currentM0.right * (currentM0.left - currentM0.right) + + deltaN * (currentM0.left * currentM2.right - currentM0.right * currentM2.left) * + Cast(Literal(3), resultType), Cast(Literal(0), resultType)) + + lazy val (tmpM4, tmpM4Err) = kahan(currentM4.left, currentM4Err.left, currentM4.right, currentM4Err.right) + lazy val (updateM4, updateM4Err) = kahan(tmpM4, tmpM4Err, deltaN * deltaN * deltaN * delta * currentM0.left * + currentM0.right * (currentM0.left * currentM0.left - currentM0.left * currentM0.right + + currentM0.right * currentM0.right) + deltaN * deltaN * Cast(Literal(6), resultType) * + (currentM0.left * currentM0.left * currentM2.right + + currentM0.right * currentM0.right * currentM2.left) + + deltaN * Cast(Literal(4), resultType) * (currentM0.left * currentM3.right - currentM0.right * currentM3.left), + Cast(Literal(0), resultType)) Seq( - /* currentMean = */ If(IsNull(currentMean.left), currentMean.right, - If(IsNull(currentMean.right), currentMean.left, updateMean)), - If(IsNull(currentMeanErr.left), currentMeanErr.right, - If(IsNull(currentMeanErr.right), currentMeanErr.left, updateMeanErr)), - /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, - If(IsNull(currentCount.right), currentCount.left, updateCount)) + /* currentM4 = */ If(IsNull(currentM4.left), currentM4.right, + If(IsNull(currentM4.right), currentM4.left, updateM4)), + /* currentM3 = */ If(IsNull(currentM3.left), currentM3.right, + If(IsNull(currentM3.right), currentM3.left, updateM3)), + /* currentM2 = */ If(IsNull(currentM2.left), currentM2.right, + If(IsNull(currentM2.right), currentM2.left, updateM2)), + /* currentM1 = */ If(IsNull(currentM1.left), currentM1.right, + If(IsNull(currentM1.right), currentM1.left, updateM1)), + /* currentM4Err = */ If(IsNull(currentM4Err.left), currentM4Err.right, + If(IsNull(currentM4Err.right), currentM4Err.left, updateM4Err)), + /* currentM3Err = */ If(IsNull(currentM3Err.left), currentM3Err.right, + If(IsNull(currentM3Err.right), currentM3Err.left, updateM3Err)), + /* currentM2Err = */ If(IsNull(currentM2Err.left), currentM2Err.right, + If(IsNull(currentM2Err.right), currentM2Err.left, updateM2Err)), + /* currentM1Err = */ If(IsNull(currentM1Err.left), currentM1Err.right, + If(IsNull(currentM1Err.right), currentM1Err.left, updateM1Err)), + /* currentM0 = */ If(IsNull(currentM0.left), currentM0.right, + If(IsNull(currentM0.right), currentM0.left, M0)) ) } +} - // If all input are nulls, currentCount will be 0 and we will get null after the division. - override val evaluateExpression = currentMean * Cast(Literal(1), resultType) +case class KahanSkewness(child: Expression) extends StableMoments(child) { + override def prettyName = "skewness" + // TODO: protect against neg sqrt + // TODO: skewness divides by zero if var is 0 + override val evaluateExpression: Expression = { + If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType))) + } +} + +case class KahanKurtosis(child: Expression) extends StableMoments(child) { + override def prettyName = "skewness" + // TODO: protect against neg sqrt + // TODO: skewness divides by zero if var is 0 + override val evaluateExpression: Expression = { + If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 298c32290697a..ccffd45c03ec1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -336,6 +336,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) + // SKEWNESS + testCodeGen( + "SELECT a, skewness(b) FROM testData2 GROUP BY a", + (1 to 3).map(i => Row(i, 0.0))) + testCodeGen( + "SELECT skewness(b) FROM testData2", + Row(0.0) :: Nil) + // KURTOSIS + testCodeGen( + "SELECT a, kurtosis(b) FROM testData2 GROUP BY a", + (1 to 3).map(i => Row(i, -2.0))) + testCodeGen( + "SELECT kurtosis(b) FROM testData2", + Row(-2.0) :: Nil) // Some combinations. testCodeGen( """ @@ -356,8 +370,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", - Row(null, null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), stddev('a'), skewness('a'), kurtosis('a'), count(null) FROM testData", + Row(null, null, null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -523,8 +537,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(1, 3, 2, 1, 6, 3) + sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(0, -1.5, 1, 3, 2, 1, 6, 3) ) } @@ -732,13 +746,31 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT STDDEV_SAMP(a) FROM testData2"), Row(math.sqrt(4/5.0)) + } + + test("skewness") { + checkAnswer( + sql("SELECT skewness(a) FROM testData2"), + Row(0.0) + ) + } + + test("kurtosis") { + checkAnswer( + sql("SELECT kurtosis(a) FROM testData2"), + Row(-1.5) ) } test("stddev agg") { checkAnswer( - sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0)))) + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, math.sqrt(1 / 2.0), math.sqrt(1 / 4.0), math.sqrt(1 / 2.0)))) + } + + test("skewness and kurtosis agg") { + sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, 0.0, -2.0))) } test("inner join where, one match per row") { From 7ecf50ea0384cf850a957c288bcdcd7726ac8a24 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 21 Sep 2015 09:48:14 -0700 Subject: [PATCH 03/22] adding zero division protection --- .../expressions/aggregate/functions.scala | 22 +++++++++++++------ .../sql/catalyst/expressions/aggregates.scala | 2 ++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 9f5df29a3bf26..36ad2576d633c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -199,21 +199,25 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate } case class Skewness(child: Expression) extends StatisticalMoments(child) { - override def prettyName = "min" + override def prettyName = "skewness" // TODO: protect against neg sqrt + // skewness = sqrt(M_0) * M_3 / M_2^(3/2) override val evaluateExpression: Expression = { If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType))) + If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType)))) } } case class Kurtosis(child: Expression) extends StatisticalMoments(child) { override def prettyName = "kurtosis" + // kurtosis = M_0 * M_4 / M_2^2 - 3 override val evaluateExpression = { If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType))) + If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), + Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType)))) } } @@ -1103,6 +1107,8 @@ abstract class StableMoments(child: Expression) extends AlgebraicAggregate { /* currentM0 = */ Cast(Literal(0), resultType) ) + // Kahan update implemented according to: + // http://researcher.watson.ibm.com/researcher/files/us-ytian/stability.pdf def kahan(s1: Expression, c1: Expression, s2: Expression, c2: Expression) = { val correctedS2 = s2 + (c1 + c2) val s = s1 + correctedS2 @@ -1195,23 +1201,25 @@ abstract class StableMoments(child: Expression) extends AlgebraicAggregate { } case class KahanSkewness(child: Expression) extends StableMoments(child) { - override def prettyName = "skewness" + override def prettyName = "kahanskewness" // TODO: protect against neg sqrt // TODO: skewness divides by zero if var is 0 override val evaluateExpression: Expression = { If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType))) + If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType)))) } } case class KahanKurtosis(child: Expression) extends StableMoments(child) { - override def prettyName = "skewness" + override def prettyName = "kahankurtosis" // TODO: protect against neg sqrt // TODO: skewness divides by zero if var is 0 override val evaluateExpression: Expression = { If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType))) + If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), + Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType)))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 75b651cdcfd6d..5315d2d0f4f7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -475,6 +475,8 @@ abstract class StatisticalMoments1(child: Expression) extends UnaryExpression wi case class Kurtosis(child: Expression) extends StatisticalMoments1(child) { + override def prettyName: String = "kurtosis" + override def toString: String = s"KURTOSIS($child)" } From 579b9f2150f1ec988b658aa5eaf9849305f7ad17 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 5 Oct 2015 21:12:48 -0700 Subject: [PATCH 04/22] Adding order check to reduce calculation overhead --- .../expressions/aggregate/functions.scala | 281 ++++++------------ .../sql/catalyst/expressions/aggregates.scala | 5 +- .../spark/sql/DataFrameAggregateSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 8 +- 4 files changed, 98 insertions(+), 198 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 36ad2576d633c..50ccde1d07687 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -92,6 +92,8 @@ case class Average(child: Expression) extends DeclarativeAggregate { abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate { + def highestOrder: Int + override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true @@ -115,7 +117,7 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate protected val currentM1 = AttributeReference("currentM1", resultType)() protected val currentM0 = AttributeReference("currentM0", resultType)() - override val bufferAttributes = currentM4 :: currentM3 :: currentM2 :: currentM1 :: currentM0 :: Nil + override val bufferAttributes = List(currentM4, currentM3, currentM2, currentM1, currentM0) override val initialValues = Seq( /* currentM4 = */ Cast(Literal(0), resultType), @@ -126,29 +128,37 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate ) override val updateExpressions = { - val x = Coalesce( - Cast(child, resultType) :: Cast(Literal(0), resultType) :: Nil - ) lazy val updateM0: Expression = { - If(IsNull(child), currentM0, currentM0 + Cast(Literal(1), resultType)) + Add(currentM0, Cast(Literal(1), resultType)) } - lazy val delta = Subtract(x, currentM1) + lazy val delta = Subtract(Cast(child, resultType), currentM1) lazy val deltaN = Divide(delta, updateM0) - lazy val updateM2: Expression = { - Add(currentM2, Multiply(deltaN * delta, Subtract(updateM0, Cast(Literal(1), resultType)))) - } lazy val updateM1: Expression = { Add(currentM1, Divide(delta, updateM0)) } - lazy val updateM3: Expression = { + + lazy val updateM2: Expression = if (highestOrder >= 2) { + Add(currentM2, Multiply(deltaN * delta, Subtract(updateM0, Cast(Literal(1), resultType)))) + } else { + Cast(Literal(0), resultType) + } + + lazy val updateM3: Expression = if (highestOrder >= 3) { currentM3 + deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * - (updateM0 - Cast(Literal(2), resultType)) - deltaN * currentM2 * Cast(Literal(3), resultType) + (updateM0 - Cast(Literal(2), resultType)) - + deltaN * currentM2 * Cast(Literal(3), resultType) + } else { + Cast(Literal(0), resultType) } - lazy val updateM4: Expression = { + + lazy val updateM4: Expression = if (highestOrder >= 4) { currentM4 + deltaN * deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * - (updateM0 * updateM0 - updateM0 * Cast(Literal(3), resultType) + Cast(Literal(3), resultType)) + - deltaN * deltaN * currentM2 * Cast(Literal(6), resultType) - deltaN * currentM3 * Cast(Literal(4), resultType) + (updateM0 * updateM0 - updateM0 * Cast(Literal(3), resultType) + + Cast(Literal(3), resultType)) + deltaN * deltaN * currentM2 * + Cast(Literal(6), resultType) - deltaN * currentM3 * Cast(Literal(4), resultType) + } else { + Cast(Literal(0), resultType) } Seq( @@ -156,7 +166,7 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate /* currentM3 = */ If(IsNull(child), currentM3, updateM3), /* currentM2 = */ If(IsNull(child), currentM2, updateM2), /* currentM1 = */ If(IsNull(child), currentM1, updateM1), - /* currentM0 = */ If(IsNull(child), currentM0,updateM0) + /* currentM0 = */ If(IsNull(child), currentM0, updateM0) ) } @@ -169,19 +179,32 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate lazy val M1 = currentM1.left + delta * (currentM0.right / M0) - lazy val M2 = currentM2.left + currentM2.right + delta * deltaN * currentM0.left * currentM0.right + lazy val M2 = if (highestOrder >= 2) { + currentM2.left + currentM2.right + delta * deltaN * currentM0.left * currentM0.right + } else { + Cast(Literal(0), resultType) + } - lazy val M3 = currentM3.left + currentM3.right + deltaN * deltaN * delta * currentM0.left * - currentM0.right * (currentM0.left - currentM0.right) + - deltaN * (currentM0.left * currentM2.right - currentM0.right * currentM2.left) * - Cast(Literal(3), resultType) + lazy val M3 = if (highestOrder >= 3) { + currentM3.left + currentM3.right + deltaN * deltaN * delta * currentM0.left * + currentM0.right * (currentM0.left - currentM0.right) + + deltaN * (currentM0.left * currentM2.right - currentM0.right * currentM2.left) * + Cast(Literal(3), resultType) + } else { + Cast(Literal(0), resultType) + } - lazy val M4 = currentM4.left + currentM4.right + deltaN * deltaN * deltaN * delta * currentM0.left * - currentM0.right * (currentM0.left * currentM0.left - currentM0.left * currentM0.right + - currentM0.right * currentM0.right) + deltaN * deltaN * Cast(Literal(6), resultType) * - (currentM0.left * currentM0.left * currentM2.right + - currentM0.right * currentM0.right * currentM2.left) + - deltaN * Cast(Literal(4), resultType) * (currentM0.left * currentM3.right - currentM0.right * currentM3.left) + lazy val M4 = if (highestOrder >= 4) { + currentM4.left + currentM4.right + deltaN * deltaN * deltaN * delta * currentM0.left * + currentM0.right * (currentM0.left * currentM0.left - currentM0.left * currentM0.right + + currentM0.right * currentM0.right) + deltaN * deltaN * Cast(Literal(6), resultType) * + (currentM0.left * currentM0.left * currentM2.right + + currentM0.right * currentM0.right * currentM2.left) + + deltaN * Cast(Literal(4), resultType) * + (currentM0.left * currentM3.right - currentM0.right * currentM3.left) + } else { + Cast(Literal(0), resultType) + } Seq( /* currentM4 = */ If(IsNull(currentM4.left), currentM4.right, @@ -198,26 +221,60 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate } } +//case class Average(child: Expression) extends StatisticalMoments(child) { +// +// override def highestOrder = 1 +// +// override def prettyName = "average" +// // average = M_1 +// override val evaluateExpression: Expression = { +// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), +// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), +// Cast(currentM1, resultType))) +// } +//} +// +//case class Variance(child: Expression) extends StatisticalMoments(child) { +// +// override def highestOrder = 2 +// +// override def prettyName = "variance" +// // variance = M_2 / M_0 +// override val evaluateExpression: Expression = { +// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), +// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), +// Cast(currentM2 / currentM0, resultType))) +// } +//} + case class Skewness(child: Expression) extends StatisticalMoments(child) { - override def prettyName = "skewness" + + override def highestOrder: Int = 3 + + override def prettyName: String = "skewness" // TODO: protect against neg sqrt // skewness = sqrt(M_0) * M_3 / M_2^(3/2) override val evaluateExpression: Expression = { If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType)))) + Cast(Sqrt(currentM0) * currentM3 / + Sqrt(currentM2 * currentM2 * currentM2), resultType)))) } } case class Kurtosis(child: Expression) extends StatisticalMoments(child) { - override def prettyName = "kurtosis" + + override def highestOrder: Int = 4 + + override def prettyName: String = "kurtosis" // kurtosis = M_0 * M_4 / M_2^2 - 3 override val evaluateExpression = { If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(-3), resultType), If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), - Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType)))) + Cast(currentM0 * currentM4 / (currentM2 * currentM2) - + Cast(Literal(3), resultType), resultType)))) } } @@ -1061,165 +1118,3 @@ object HyperLogLogPlusPlus { ) // scalastyle:on } - - -abstract class StableMoments(child: Expression) extends AlgebraicAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = resultType - - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - - protected val resultType = DoubleType - - protected val currentM4 = AttributeReference("currentM4", resultType)() - protected val currentM3 = AttributeReference("currentM3", resultType)() - protected val currentM2 = AttributeReference("currentM2", resultType)() - protected val currentM1 = AttributeReference("currentM1", resultType)() - protected val currentM4Err = AttributeReference("currentM4Err", resultType)() - protected val currentM3Err = AttributeReference("currentM3Err", resultType)() - protected val currentM2Err = AttributeReference("currentM2Err", resultType)() - protected val currentM1Err = AttributeReference("currentM1Err", resultType)() - protected val currentM0 = AttributeReference("currentM0", resultType)() - - override val bufferAttributes = currentM4 :: currentM3 :: currentM2 :: currentM1 :: - currentM4Err :: currentM3Err :: currentM2Err :: currentM1Err :: currentM0 :: Nil - - override val initialValues = Seq( - /* currentM4 = */ Cast(Literal(0), resultType), - /* currentM3 = */ Cast(Literal(0), resultType), - /* currentM2 = */ Cast(Literal(0), resultType), - /* currentM1 = */ Cast(Literal(0), resultType), - /* currentM4Err = */ Cast(Literal(0), resultType), - /* currentM3Err = */ Cast(Literal(0), resultType), - /* currentM2Err = */ Cast(Literal(0), resultType), - /* currentM1Err = */ Cast(Literal(0), resultType), - /* currentM0 = */ Cast(Literal(0), resultType) - ) - - // Kahan update implemented according to: - // http://researcher.watson.ibm.com/researcher/files/us-ytian/stability.pdf - def kahan(s1: Expression, c1: Expression, s2: Expression, c2: Expression) = { - val correctedS2 = s2 + (c1 + c2) - val s = s1 + correctedS2 - val diff = (s - s1) - val correction = correctedS2 - diff - (s, correction) - } - - override val updateExpressions = { - val x = Coalesce( - Cast(child, resultType) :: Cast(Literal(0), resultType) :: Nil - ) - lazy val updateM0: Expression = { - If(IsNull(child), currentM0, currentM0 + Cast(Literal(1), resultType)) - } - lazy val delta = Subtract(x, currentM1) - lazy val deltaN = Divide(delta, updateM0) - - lazy val (updateM2, updateM2Err) = kahan(currentM2, currentM2Err, deltaN * delta * (updateM0 - Cast(Literal(1), resultType)), Cast(Literal(0), resultType)) - lazy val (updateM1, updateM1Err) = kahan(currentM1, currentM1Err, deltaN, Cast(Literal(0), resultType)) - - lazy val (updateM3, updateM3Err) = kahan(currentM3, currentM3Err, deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * - (updateM0 - Cast(Literal(2), resultType)) - deltaN * currentM2 * Cast(Literal(3), resultType), Cast(Literal(0), resultType)) - - lazy val (updateM4, updateM4Err) = kahan(currentM4, currentM4Err, deltaN * deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * - (updateM0 * updateM0 - updateM0 * Cast(Literal(3), resultType) + Cast(Literal(3), resultType)) + - deltaN * deltaN * currentM2 * Cast(Literal(6), resultType) - deltaN * currentM3 * Cast(Literal(4), resultType), - Cast(Literal(0), resultType)) - - Seq( - /* currentM4 = */ If(IsNull(child), currentM4, updateM4), - /* currentM3 = */ If(IsNull(child), currentM3, updateM3), - /* currentM2 = */ If(IsNull(child), currentM2, updateM2), - /* currentM1 = */ If(IsNull(child), currentM1, updateM1), - /* currentM4Err = */ If(IsNull(child), currentM4Err, updateM4Err), - /* currentM3Err = */ If(IsNull(child), currentM3Err, updateM3Err), - /* currentM2Err = */ If(IsNull(child), currentM2Err, updateM2Err), - /* currentM1Err = */ If(IsNull(child), currentM1Err, updateM1Err), - /* currentM0 = */ If(IsNull(child), currentM0, updateM0) - ) - } - - override val mergeExpressions = { - - lazy val M0 = currentM0.left + currentM0.right - lazy val delta = currentM1.right - currentM1.left - lazy val deltaN = delta / M0 - - lazy val (updateM1, updateM1Err) = kahan(currentM1.left, currentM1Err.left, delta * (currentM0.right / M0), Cast(Literal(0), resultType)) - - lazy val (tmpM2, tmpM2Err) = kahan(currentM2.left, currentM2Err.left, currentM2.right, currentM2Err.right) - lazy val (updateM2, updateM2Err) = kahan(tmpM2, tmpM2Err, delta * deltaN * currentM0.left * currentM0.right, Cast(Literal(0), resultType)) - - lazy val (tmpM3, tmpM3Err) = kahan(currentM3.left, currentM3Err.left, currentM3.right, currentM3Err.right) - lazy val (updateM3, updateM3Err) = kahan(tmpM3, tmpM3Err, deltaN * deltaN * delta * currentM0.left * - currentM0.right * (currentM0.left - currentM0.right) + - deltaN * (currentM0.left * currentM2.right - currentM0.right * currentM2.left) * - Cast(Literal(3), resultType), Cast(Literal(0), resultType)) - - lazy val (tmpM4, tmpM4Err) = kahan(currentM4.left, currentM4Err.left, currentM4.right, currentM4Err.right) - lazy val (updateM4, updateM4Err) = kahan(tmpM4, tmpM4Err, deltaN * deltaN * deltaN * delta * currentM0.left * - currentM0.right * (currentM0.left * currentM0.left - currentM0.left * currentM0.right + - currentM0.right * currentM0.right) + deltaN * deltaN * Cast(Literal(6), resultType) * - (currentM0.left * currentM0.left * currentM2.right + - currentM0.right * currentM0.right * currentM2.left) + - deltaN * Cast(Literal(4), resultType) * (currentM0.left * currentM3.right - currentM0.right * currentM3.left), - Cast(Literal(0), resultType)) - - Seq( - /* currentM4 = */ If(IsNull(currentM4.left), currentM4.right, - If(IsNull(currentM4.right), currentM4.left, updateM4)), - /* currentM3 = */ If(IsNull(currentM3.left), currentM3.right, - If(IsNull(currentM3.right), currentM3.left, updateM3)), - /* currentM2 = */ If(IsNull(currentM2.left), currentM2.right, - If(IsNull(currentM2.right), currentM2.left, updateM2)), - /* currentM1 = */ If(IsNull(currentM1.left), currentM1.right, - If(IsNull(currentM1.right), currentM1.left, updateM1)), - /* currentM4Err = */ If(IsNull(currentM4Err.left), currentM4Err.right, - If(IsNull(currentM4Err.right), currentM4Err.left, updateM4Err)), - /* currentM3Err = */ If(IsNull(currentM3Err.left), currentM3Err.right, - If(IsNull(currentM3Err.right), currentM3Err.left, updateM3Err)), - /* currentM2Err = */ If(IsNull(currentM2Err.left), currentM2Err.right, - If(IsNull(currentM2Err.right), currentM2Err.left, updateM2Err)), - /* currentM1Err = */ If(IsNull(currentM1Err.left), currentM1Err.right, - If(IsNull(currentM1Err.right), currentM1Err.left, updateM1Err)), - /* currentM0 = */ If(IsNull(currentM0.left), currentM0.right, - If(IsNull(currentM0.right), currentM0.left, M0)) - ) - } -} - -case class KahanSkewness(child: Expression) extends StableMoments(child) { - override def prettyName = "kahanskewness" - // TODO: protect against neg sqrt - // TODO: skewness divides by zero if var is 0 - override val evaluateExpression: Expression = { - If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(currentM0) * currentM3 / Sqrt(currentM2 * currentM2 * currentM2), resultType)))) - } -} - -case class KahanKurtosis(child: Expression) extends StableMoments(child) { - override def prettyName = "kahankurtosis" - // TODO: protect against neg sqrt - // TODO: skewness divides by zero if var is 0 - override val evaluateExpression: Expression = { - If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), - Cast(currentM0 * currentM4 / (currentM2 * currentM2) - Cast(Literal(3), resultType), resultType)))) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5315d2d0f4f7f..67a727562ea46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -428,10 +428,9 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg TypeUtils.checkForNumericExpr(child.dataType, "function average") } +// placeholder so code will compile abstract class StatisticalMoments1(child: Expression) extends UnaryExpression with PartialAggregate1 { - override def prettyName: String = "kurtosis" - override def nullable: Boolean = true override def dataType: DataType = child.dataType match { @@ -473,6 +472,7 @@ abstract class StatisticalMoments1(child: Expression) extends UnaryExpression wi TypeUtils.checkForNumericExpr(child.dataType, "function average") } +// placeholder so code will compile case class Kurtosis(child: Expression) extends StatisticalMoments1(child) { override def prettyName: String = "kurtosis" @@ -480,6 +480,7 @@ case class Kurtosis(child: Expression) extends StatisticalMoments1(child) { override def toString: String = s"KURTOSIS($child)" } +// placeholder so code will compile case class Skewness(child: Expression) extends StatisticalMoments1(child) { override def prettyName: String = "skewness" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 07b868639b24f..909672dec77c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -233,7 +233,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("zero moments") { - val emptyTableData = Seq((1,2)).toDF("a", "b") + val emptyTableData = Seq((1, 2)).toDF("a", "b") assert(emptyTableData.count() === 1) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ccffd45c03ec1..c5b05fcef9e0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -370,7 +370,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), stddev('a'), skewness('a'), kurtosis('a'), count(null) FROM testData", + "SELECT sum('a'), avg('a'), stddev('a'), skewness('a')," + + "kurtosis('a'), count(null) FROM testData", Row(null, null, null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") @@ -537,7 +538,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( - sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + + "AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), Row(0, -1.5, 1, 3, 2, 1, 6, 3) ) } @@ -746,6 +748,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT STDDEV_SAMP(a) FROM testData2"), Row(math.sqrt(4/5.0)) + ) } test("skewness") { @@ -769,6 +772,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("skewness and kurtosis agg") { + checkAnswer( sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a"), (1 to 3).map(i => Row(i, 0.0, -2.0))) } From 230f66c63076fb5006a2b1727a11fc7f92de1a21 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 5 Oct 2015 21:20:41 -0700 Subject: [PATCH 05/22] style and scaladoc fixes --- .../scala/org/apache/spark/sql/GroupedData.scala | 12 ++++++------ .../scala/org/apache/spark/sql/functions.scala | 16 ++++++++-------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 9b242caac994e..a524314532bad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -253,11 +253,11 @@ class GroupedData protected[sql]( } /** - * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + * Compute the skewness for each numeric columns for each group. This is an alias for `skewness`. * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the average values for them. + * When specified columns are given, only compute the skewness values for them. * - * @since 1.3.0 + * @since 1.6.0 */ @scala.annotation.varargs def skewness(colNames: String*): DataFrame = { @@ -265,11 +265,11 @@ class GroupedData protected[sql]( } /** - * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + * Compute the kurtosis for each numeric columns for each group. This is an alias for `kurtosis`. * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the average values for them. + * When specified columns are given, only compute the kurtosis values for them. * - * @since 1.3.0 + * @since 1.6.0 */ @scala.annotation.varargs def kurtosis(colNames: String*): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 23fff2d07f80e..a052745a1de34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -279,8 +279,8 @@ object functions { def mean(columnName: String): Column = avg(columnName) /** - * Aggregate function: returns the average of the values in a group. - * Alias for avg. + * Aggregate function: returns the skewness of the values in a group. + * Alias for skewness. * * @group agg_funcs * @since 1.6.0 @@ -288,8 +288,8 @@ object functions { def skewness(e: Column): Column = Skewness(e.expr) /** - * Aggregate function: returns the average of the values in a group. - * Alias for avg. + * Aggregate function: returns the skewness of the values in a group. + * Alias for skewness. * * @group agg_funcs * @since 1.6.0 @@ -297,8 +297,8 @@ object functions { def skewness(columnName: String): Column = skewness(Column(columnName)) /** - * Aggregate function: returns the average of the values in a group. - * Alias for avg. + * Aggregate function: returns the kurtosis of the values in a group. + * Alias for kurtosis. * * @group agg_funcs * @since 1.6.0 @@ -306,8 +306,8 @@ object functions { def kurtosis(e: Column): Column = Kurtosis(e.expr) /** - * Aggregate function: returns the average of the values in a group. - * Alias for avg. + * Aggregate function: returns the kurtosis of the values in a group. + * Alias for kurtosis. * * @group agg_funcs * @since 1.6.0 From 1c4c4d0e38713d7c1fde8fdc169acc4d84c9711f Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 6 Oct 2015 13:56:38 -0700 Subject: [PATCH 06/22] updating kurtosis test --- .../expressions/aggregate/functions.scala | 140 ++++++++++++++---- .../sql/catalyst/expressions/aggregates.scala | 121 +++++++-------- .../spark/sql/DataFrameAggregateSuite.scala | 2 +- 3 files changed, 173 insertions(+), 90 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 50ccde1d07687..6930d32997e2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -90,9 +90,10 @@ case class Average(child: Expression) extends DeclarativeAggregate { } } -abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate { +abstract class CentralMomentAgg(child: Expression) extends AlgebraicAggregate { - def highestOrder: Int + // specify the maximum order moment needed for the computation + def maxMoment: Int override def children: Seq[Expression] = child :: Nil @@ -134,17 +135,19 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate lazy val delta = Subtract(Cast(child, resultType), currentM1) lazy val deltaN = Divide(delta, updateM0) - lazy val updateM1: Expression = { + lazy val updateM1: Expression = if (maxMoment >= 1) { Add(currentM1, Divide(delta, updateM0)) + } else { + Cast(Literal(0), resultType) } - lazy val updateM2: Expression = if (highestOrder >= 2) { + lazy val updateM2: Expression = if (maxMoment >= 2) { Add(currentM2, Multiply(deltaN * delta, Subtract(updateM0, Cast(Literal(1), resultType)))) } else { Cast(Literal(0), resultType) } - lazy val updateM3: Expression = if (highestOrder >= 3) { + lazy val updateM3: Expression = if (maxMoment >= 3) { currentM3 + deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * (updateM0 - Cast(Literal(2), resultType)) - deltaN * currentM2 * Cast(Literal(3), resultType) @@ -152,7 +155,7 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate Cast(Literal(0), resultType) } - lazy val updateM4: Expression = if (highestOrder >= 4) { + lazy val updateM4: Expression = if (maxMoment >= 4) { currentM4 + deltaN * deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * (updateM0 * updateM0 - updateM0 * Cast(Literal(3), resultType) + Cast(Literal(3), resultType)) + deltaN * deltaN * currentM2 * @@ -173,19 +176,23 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate override val mergeExpressions = { - lazy val M0 = currentM0.left + currentM0.right + lazy val updateM0 = currentM0.left + currentM0.right lazy val delta = currentM1.right - currentM1.left - lazy val deltaN = delta / M0 + lazy val deltaN = delta / updateM0 - lazy val M1 = currentM1.left + delta * (currentM0.right / M0) + lazy val updateM1 = if (maxMoment >= 1) { + currentM1.left + delta * (currentM0.right / updateM0) + } else { + Cast(Literal(0), resultType) + } - lazy val M2 = if (highestOrder >= 2) { + lazy val updateM2 = if (maxMoment >= 2) { currentM2.left + currentM2.right + delta * deltaN * currentM0.left * currentM0.right } else { Cast(Literal(0), resultType) } - lazy val M3 = if (highestOrder >= 3) { + lazy val updateM3 = if (maxMoment >= 3) { currentM3.left + currentM3.right + deltaN * deltaN * delta * currentM0.left * currentM0.right * (currentM0.left - currentM0.right) + deltaN * (currentM0.left * currentM2.right - currentM0.right * currentM2.left) * @@ -194,7 +201,7 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate Cast(Literal(0), resultType) } - lazy val M4 = if (highestOrder >= 4) { + lazy val updateM4 = if (maxMoment >= 4) { currentM4.left + currentM4.right + deltaN * deltaN * deltaN * delta * currentM0.left * currentM0.right * (currentM0.left * currentM0.left - currentM0.left * currentM0.right + currentM0.right * currentM0.right) + deltaN * deltaN * Cast(Literal(6), resultType) * @@ -208,24 +215,25 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate Seq( /* currentM4 = */ If(IsNull(currentM4.left), currentM4.right, - If(IsNull(currentM4.right), currentM4.left, M4)), + If(IsNull(currentM4.right), currentM4.left, updateM4)), /* currentM3 = */ If(IsNull(currentM3.left), currentM3.right, - If(IsNull(currentM3.right), currentM3.left, M3)), + If(IsNull(currentM3.right), currentM3.left, updateM3)), /* currentM2 = */ If(IsNull(currentM2.left), currentM2.right, - If(IsNull(currentM2.right), currentM2.left, M2)), + If(IsNull(currentM2.right), currentM2.left, updateM2)), /* currentM1 = */ If(IsNull(currentM1.left), currentM1.right, - If(IsNull(currentM1.right), currentM1.left, M1)), + If(IsNull(currentM1.right), currentM1.left, updateM1)), /* currentM0 = */ If(IsNull(currentM0.left), currentM0.right, - If(IsNull(currentM0.right), currentM0.left, M0)) + If(IsNull(currentM0.right), currentM0.left, updateM0)) ) } } -//case class Average(child: Expression) extends StatisticalMoments(child) { +//case class Average(child: Expression) extends CentralMomentAgg(child) { +// +// override def maxMoment: Int = 1 // -// override def highestOrder = 1 +// override def prettyName: String = "average" // -// override def prettyName = "average" // // average = M_1 // override val evaluateExpression: Expression = { // If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), @@ -234,24 +242,96 @@ abstract class StatisticalMoments(child: Expression) extends AlgebraicAggregate // } //} // -//case class Variance(child: Expression) extends StatisticalMoments(child) { +//case class Stddev(child: Expression) extends CentralMomentAgg(child) { +// +// override def maxMoment: Int = 2 +// +// override def prettyName: String = "stddev" +// +// // stddev = sqrt(M_2 / (M_0 - 1)) +// override val evaluateExpression: Expression = { +// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), +// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), +// Cast(Sqrt(currentM2 / (currentM0 - Cast(Literal(1), resultType))), resultType))) +// } +//} +// +//case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { +// +// override def maxMoment: Int = 2 +// +// override def prettyName: String = "stddev_samp" +// +// // stddev_samp = sqrt(M_2 / (M_0 - 1)) +// override val evaluateExpression: Expression = { +// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), +// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), +// Cast(Sqrt(currentM2 / (currentM0 - Cast(Literal(1), resultType))), resultType))) +// } +//} +// +//case class StddevPop(child: Expression) extends CentralMomentAgg(child) { +// +// override def maxMoment: Int = 2 // -// override def highestOrder = 2 +// override def prettyName: String = "stddev_pop" +// +// // stddev_pop = sqrt(M_2 / M_0) +// override val evaluateExpression: Expression = { +// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), +// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), +// Cast(Sqrt(currentM2 / currentM0), resultType))) +// } +//} +// +//case class Variance(child: Expression) extends CentralMomentAgg(child) { +// +// override def maxMoment: Int = 2 +// +// override def prettyName: String = "variance" +// +// // variance = M_2 / (M_0 - 1) +// override val evaluateExpression: Expression = { +// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), +// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), +// Cast(currentM2 / (currentM0 - Cast(Literal(1), resultType)), resultType))) +// } +//} +// +//case class VariancePop(child: Expression) extends CentralMomentAgg(child) { +// +// override def maxMoment: Int = 2 +// +// override def prettyName: String = "var_pop" // -// override def prettyName = "variance" // // variance = M_2 / M_0 // override val evaluateExpression: Expression = { // If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), // If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), -// Cast(currentM2 / currentM0, resultType))) +// Cast(currentM2 / currentM0, resultType))) +// } +//} +// +//case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { +// +// override def maxMoment: Int = 2 +// +// override def prettyName: String = "var_samp" +// +// // variance = M_2 / (M_0 - 1) +// override val evaluateExpression: Expression = { +// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), +// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), +// Cast(currentM2 / (currentM0 - Cast(Literal(1), resultType)), resultType))) // } //} -case class Skewness(child: Expression) extends StatisticalMoments(child) { +case class Skewness(child: Expression) extends CentralMomentAgg(child) { - override def highestOrder: Int = 3 + override def maxMoment: Int = 3 override def prettyName: String = "skewness" + // TODO: protect against neg sqrt // skewness = sqrt(M_0) * M_3 / M_2^(3/2) override val evaluateExpression: Expression = { @@ -263,12 +343,14 @@ case class Skewness(child: Expression) extends StatisticalMoments(child) { } } -case class Kurtosis(child: Expression) extends StatisticalMoments(child) { +case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { - override def highestOrder: Int = 4 + override def maxMoment: Int = 4 override def prettyName: String = "kurtosis" + // kurtosis = M_0 * M_4 / M_2^2 - 3 + // NOTE: this is the formula for excess kurtosis, which is default for R and NumPy override val evaluateExpression = { If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(-3), resultType), @@ -289,7 +371,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val currentCount = AttributeReference("currentCount", LongType)() + private val currentCount = AttributeReference("currentCount", LongType)() override val aggBufferAttributes = currentCount :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 67a727562ea46..aea88fed11c9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -428,66 +428,6 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg TypeUtils.checkForNumericExpr(child.dataType, "function average") } -// placeholder so code will compile -abstract class StatisticalMoments1(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 4 digits after decimal point, like Hive - DecimalType.bounded(precision + 4, scale + 4) - case _ => - DoubleType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(precision, scale) => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - // partialSum already increase the precision by 10 - val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) - SplitEvaluation( - Cast(Divide(castedSum, castedCount), dataType), - partialCount :: partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) - } - } - - override def newInstance(): AverageFunction = new AverageFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") -} - -// placeholder so code will compile -case class Kurtosis(child: Expression) extends StatisticalMoments1(child) { - - override def prettyName: String = "kurtosis" - - override def toString: String = s"KURTOSIS($child)" -} - -// placeholder so code will compile -case class Skewness(child: Expression) extends StatisticalMoments1(child) { - - override def prettyName: String = "skewness" - - override def toString: String = s"SKEWNESS($child)" -} - case class AverageFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { @@ -1051,3 +991,64 @@ case class StddevFunction( } } } + +// placeholder +abstract class CentralMomentAgg1(child: Expression) + extends UnaryExpression with PartialAggregate1 { + + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + // Add 4 digits after decimal point, like Hive + DecimalType.bounded(precision + 4, scale + 4) + case _ => + DoubleType + } + + override def asPartial: SplitEvaluation = { + child.dataType match { + case DecimalType.Fixed(precision, scale) => + val partialSum = Alias(Sum(child), "PartialSum")() + val partialCount = Alias(Count(child), "PartialCount")() + + // partialSum already increase the precision by 10 + val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) + val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) + SplitEvaluation( + Cast(Divide(castedSum, castedCount), dataType), + partialCount :: partialSum :: Nil) + + case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + val partialCount = Alias(Count(child), "PartialCount")() + + val castedSum = Cast(Sum(partialSum.toAttribute), dataType) + val castedCount = Cast(Sum(partialCount.toAttribute), dataType) + SplitEvaluation( + Divide(castedSum, castedCount), + partialCount :: partialSum :: Nil) + } + } + + override def newInstance(): AverageFunction = new AverageFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") +} + +// placeholder +case class Kurtosis(child: Expression) extends CentralMomentAgg1(child) { + + override def prettyName: String = "kurtosis" + + override def toString: String = s"KURTOSIS($child)" +} + +// placeholder +case class Skewness(child: Expression) extends CentralMomentAgg1(child) { + + override def prettyName: String = "skewness" + + override def toString: String = s"SKEWNESS($child)" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 909672dec77c3..0475d98fb8738 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -242,7 +242,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(kurtosis('a)), - Row(0.0)) + Row(-3.0)) } test("null moments") { From dc223bc183e120a1378f81ab669acf2b632c4a13 Mon Sep 17 00:00:00 2001 From: sethah Date: Sun, 18 Oct 2015 15:50:43 -0700 Subject: [PATCH 07/22] converting from codegen to imperative aggregate --- .../catalyst/analysis/FunctionRegistry.scala | 3 + .../catalyst/analysis/HiveTypeCoercion.scala | 3 + .../spark/sql/catalyst/dsl/package.scala | 3 + .../expressions/aggregate/functions.scala | 714 ++++++++---------- .../expressions/aggregate/utils.scala | 46 +- .../sql/catalyst/expressions/aggregates.scala | 50 +- .../spark/sql/catalyst/expressions/rows.scala | 2 +- .../org/apache/spark/sql/GroupedData.scala | 42 +- .../org/apache/spark/sql/functions.scala | 117 ++- .../org/apache/spark/sql/SQLQuerySuite.scala | 65 +- 10 files changed, 568 insertions(+), 477 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d513bf8cdfd63..ed9fcfe014f0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -189,6 +189,9 @@ object FunctionRegistry { expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), + expression[Variance]("variance"), + expression[VariancePop]("var_pop"), + expression[VarianceSamp]("var_samp"), expression[Skewness]("skewness"), expression[Kurtosis]("kurtosis"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index a8300069ede55..3c675672dab85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -300,6 +300,9 @@ object HiveTypeCoercion { case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case Variance(e @ StringType()) => Variance(Cast(e, DoubleType)) + case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) + case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index e7224ae243ece..787f67a297e33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -162,6 +162,9 @@ package object dsl { def stddev(e: Expression): Expression = Stddev(e) def stddev_pop(e: Expression): Expression = StddevPop(e) def stddev_samp(e: Expression): Expression = StddevSamp(e) + def variance(e: Expression): Expression = Variance(e) + def var_pop(e: Expression): Expression = VariancePop(e) + def var_samp(e: Expression): Expression = VarianceSamp(e) def skewness(e: Expression): Expression = Skewness(e) def kurtosis(e: Expression): Expression = Kurtosis(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 6930d32997e2f..6296230c723ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -90,276 +90,6 @@ case class Average(child: Expression) extends DeclarativeAggregate { } } -abstract class CentralMomentAgg(child: Expression) extends AlgebraicAggregate { - - // specify the maximum order moment needed for the computation - def maxMoment: Int - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = resultType - - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - - protected val resultType = DoubleType - - protected val currentM4 = AttributeReference("currentM4", resultType)() - protected val currentM3 = AttributeReference("currentM3", resultType)() - protected val currentM2 = AttributeReference("currentM2", resultType)() - protected val currentM1 = AttributeReference("currentM1", resultType)() - protected val currentM0 = AttributeReference("currentM0", resultType)() - - override val bufferAttributes = List(currentM4, currentM3, currentM2, currentM1, currentM0) - - override val initialValues = Seq( - /* currentM4 = */ Cast(Literal(0), resultType), - /* currentM3 = */ Cast(Literal(0), resultType), - /* currentM2 = */ Cast(Literal(0), resultType), - /* currentM1 = */ Cast(Literal(0), resultType), - /* currentM0 = */ Cast(Literal(0), resultType) - ) - - override val updateExpressions = { - lazy val updateM0: Expression = { - Add(currentM0, Cast(Literal(1), resultType)) - } - lazy val delta = Subtract(Cast(child, resultType), currentM1) - lazy val deltaN = Divide(delta, updateM0) - - lazy val updateM1: Expression = if (maxMoment >= 1) { - Add(currentM1, Divide(delta, updateM0)) - } else { - Cast(Literal(0), resultType) - } - - lazy val updateM2: Expression = if (maxMoment >= 2) { - Add(currentM2, Multiply(deltaN * delta, Subtract(updateM0, Cast(Literal(1), resultType)))) - } else { - Cast(Literal(0), resultType) - } - - lazy val updateM3: Expression = if (maxMoment >= 3) { - currentM3 + deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * - (updateM0 - Cast(Literal(2), resultType)) - - deltaN * currentM2 * Cast(Literal(3), resultType) - } else { - Cast(Literal(0), resultType) - } - - lazy val updateM4: Expression = if (maxMoment >= 4) { - currentM4 + deltaN * deltaN * deltaN * delta * (updateM0 - Cast(Literal(1), resultType)) * - (updateM0 * updateM0 - updateM0 * Cast(Literal(3), resultType) + - Cast(Literal(3), resultType)) + deltaN * deltaN * currentM2 * - Cast(Literal(6), resultType) - deltaN * currentM3 * Cast(Literal(4), resultType) - } else { - Cast(Literal(0), resultType) - } - - Seq( - /* currentM4 = */ If(IsNull(child), currentM4, updateM4), - /* currentM3 = */ If(IsNull(child), currentM3, updateM3), - /* currentM2 = */ If(IsNull(child), currentM2, updateM2), - /* currentM1 = */ If(IsNull(child), currentM1, updateM1), - /* currentM0 = */ If(IsNull(child), currentM0, updateM0) - ) - } - - override val mergeExpressions = { - - - lazy val updateM0 = currentM0.left + currentM0.right - lazy val delta = currentM1.right - currentM1.left - lazy val deltaN = delta / updateM0 - - lazy val updateM1 = if (maxMoment >= 1) { - currentM1.left + delta * (currentM0.right / updateM0) - } else { - Cast(Literal(0), resultType) - } - - lazy val updateM2 = if (maxMoment >= 2) { - currentM2.left + currentM2.right + delta * deltaN * currentM0.left * currentM0.right - } else { - Cast(Literal(0), resultType) - } - - lazy val updateM3 = if (maxMoment >= 3) { - currentM3.left + currentM3.right + deltaN * deltaN * delta * currentM0.left * - currentM0.right * (currentM0.left - currentM0.right) + - deltaN * (currentM0.left * currentM2.right - currentM0.right * currentM2.left) * - Cast(Literal(3), resultType) - } else { - Cast(Literal(0), resultType) - } - - lazy val updateM4 = if (maxMoment >= 4) { - currentM4.left + currentM4.right + deltaN * deltaN * deltaN * delta * currentM0.left * - currentM0.right * (currentM0.left * currentM0.left - currentM0.left * currentM0.right + - currentM0.right * currentM0.right) + deltaN * deltaN * Cast(Literal(6), resultType) * - (currentM0.left * currentM0.left * currentM2.right + - currentM0.right * currentM0.right * currentM2.left) + - deltaN * Cast(Literal(4), resultType) * - (currentM0.left * currentM3.right - currentM0.right * currentM3.left) - } else { - Cast(Literal(0), resultType) - } - - Seq( - /* currentM4 = */ If(IsNull(currentM4.left), currentM4.right, - If(IsNull(currentM4.right), currentM4.left, updateM4)), - /* currentM3 = */ If(IsNull(currentM3.left), currentM3.right, - If(IsNull(currentM3.right), currentM3.left, updateM3)), - /* currentM2 = */ If(IsNull(currentM2.left), currentM2.right, - If(IsNull(currentM2.right), currentM2.left, updateM2)), - /* currentM1 = */ If(IsNull(currentM1.left), currentM1.right, - If(IsNull(currentM1.right), currentM1.left, updateM1)), - /* currentM0 = */ If(IsNull(currentM0.left), currentM0.right, - If(IsNull(currentM0.right), currentM0.left, updateM0)) - ) - } -} - -//case class Average(child: Expression) extends CentralMomentAgg(child) { -// -// override def maxMoment: Int = 1 -// -// override def prettyName: String = "average" -// -// // average = M_1 -// override val evaluateExpression: Expression = { -// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), -// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), -// Cast(currentM1, resultType))) -// } -//} -// -//case class Stddev(child: Expression) extends CentralMomentAgg(child) { -// -// override def maxMoment: Int = 2 -// -// override def prettyName: String = "stddev" -// -// // stddev = sqrt(M_2 / (M_0 - 1)) -// override val evaluateExpression: Expression = { -// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), -// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), -// Cast(Sqrt(currentM2 / (currentM0 - Cast(Literal(1), resultType))), resultType))) -// } -//} -// -//case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { -// -// override def maxMoment: Int = 2 -// -// override def prettyName: String = "stddev_samp" -// -// // stddev_samp = sqrt(M_2 / (M_0 - 1)) -// override val evaluateExpression: Expression = { -// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), -// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), -// Cast(Sqrt(currentM2 / (currentM0 - Cast(Literal(1), resultType))), resultType))) -// } -//} -// -//case class StddevPop(child: Expression) extends CentralMomentAgg(child) { -// -// override def maxMoment: Int = 2 -// -// override def prettyName: String = "stddev_pop" -// -// // stddev_pop = sqrt(M_2 / M_0) -// override val evaluateExpression: Expression = { -// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), -// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), -// Cast(Sqrt(currentM2 / currentM0), resultType))) -// } -//} -// -//case class Variance(child: Expression) extends CentralMomentAgg(child) { -// -// override def maxMoment: Int = 2 -// -// override def prettyName: String = "variance" -// -// // variance = M_2 / (M_0 - 1) -// override val evaluateExpression: Expression = { -// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), -// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), -// Cast(currentM2 / (currentM0 - Cast(Literal(1), resultType)), resultType))) -// } -//} -// -//case class VariancePop(child: Expression) extends CentralMomentAgg(child) { -// -// override def maxMoment: Int = 2 -// -// override def prettyName: String = "var_pop" -// -// // variance = M_2 / M_0 -// override val evaluateExpression: Expression = { -// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), -// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), -// Cast(currentM2 / currentM0, resultType))) -// } -//} -// -//case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { -// -// override def maxMoment: Int = 2 -// -// override def prettyName: String = "var_samp" -// -// // variance = M_2 / (M_0 - 1) -// override val evaluateExpression: Expression = { -// If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), -// If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), -// Cast(currentM2 / (currentM0 - Cast(Literal(1), resultType)), resultType))) -// } -//} - -case class Skewness(child: Expression) extends CentralMomentAgg(child) { - - override def maxMoment: Int = 3 - - override def prettyName: String = "skewness" - - // TODO: protect against neg sqrt - // skewness = sqrt(M_0) * M_3 / M_2^(3/2) - override val evaluateExpression: Expression = { - If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(currentM0) * currentM3 / - Sqrt(currentM2 * currentM2 * currentM2), resultType)))) - } -} - -case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { - - override def maxMoment: Int = 4 - - override def prettyName: String = "kurtosis" - - // kurtosis = M_0 * M_4 / M_2^2 - 3 - // NOTE: this is the formula for excess kurtosis, which is default for R and NumPy - override val evaluateExpression = { - If(EqualTo(currentM0, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentM0, Cast(Literal(1), resultType)), Cast(Literal(-3), resultType), - If(EqualTo(currentM2, Cast(Literal(0), resultType)), Cast(Literal(0), resultType), - Cast(currentM0 * currentM4 / (currentM2 * currentM2) - - Cast(Literal(3), resultType), resultType)))) - } -} - case class Count(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil @@ -597,149 +327,6 @@ case class Min(child: Expression) extends DeclarativeAggregate { override val evaluateExpression = min } -// Compute the sample standard deviation of a column -case class Stddev(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = true - override def prettyName: String = "stddev" -} - -// Compute the population standard deviation of a column -case class StddevPop(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = false - override def prettyName: String = "stddev_pop" -} - -// Compute the sample standard deviation of a column -case class StddevSamp(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = true - override def prettyName: String = "stddev_samp" -} - -// Compute standard deviation based on online algorithm specified here: -// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - def isSample: Boolean - - // Return data type. - override def dataType: DataType = resultType - - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select stddev(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - - private val resultType = DoubleType - - private val preCount = AttributeReference("preCount", resultType)() - private val currentCount = AttributeReference("currentCount", resultType)() - private val preAvg = AttributeReference("preAvg", resultType)() - private val currentAvg = AttributeReference("currentAvg", resultType)() - private val currentMk = AttributeReference("currentMk", resultType)() - - override val aggBufferAttributes = preCount :: currentCount :: preAvg :: - currentAvg :: currentMk :: Nil - - override val initialValues = Seq( - /* preCount = */ Cast(Literal(0), resultType), - /* currentCount = */ Cast(Literal(0), resultType), - /* preAvg = */ Cast(Literal(0), resultType), - /* currentAvg = */ Cast(Literal(0), resultType), - /* currentMk = */ Cast(Literal(0), resultType) - ) - - override val updateExpressions = { - - // update average - // avg = avg + (value - avg)/count - def avgAdd: Expression = { - currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount) - } - - // update sum of square of difference from mean - // Mk = Mk + (value - preAvg) * (value - updatedAvg) - def mkAdd: Expression = { - val delta1 = Cast(child, resultType) - preAvg - val delta2 = Cast(child, resultType) - currentAvg - currentMk + (delta1 * delta2) - } - - Seq( - /* preCount = */ If(IsNull(child), preCount, currentCount), - /* currentCount = */ If(IsNull(child), currentCount, - Add(currentCount, Cast(Literal(1), resultType))), - /* preAvg = */ If(IsNull(child), preAvg, currentAvg), - /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), - /* currentMk = */ If(IsNull(child), currentMk, mkAdd) - ) - } - - override val mergeExpressions = { - - // count merge - def countMerge: Expression = { - currentCount.left + currentCount.right - } - - // average merge - def avgMerge: Expression = { - ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / - (preCount + currentCount.right) - } - - // update sum of square differences - def mkMerge: Expression = { - val avgDelta = currentAvg.right - preAvg - val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / - (preCount + currentCount.right) - - currentMk.left + currentMk.right + mkDelta - } - - Seq( - /* preCount = */ If(IsNull(currentCount.left), - Cast(Literal(0), resultType), currentCount.left), - /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, - If(IsNull(currentCount.right), currentCount.left, countMerge)), - /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), - /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, - If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), - /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, - If(IsNull(currentMk.right), currentMk.left, mkMerge)) - ) - } - - override val evaluateExpression = { - // when currentCount == 0, return null - // when currentCount == 1, return 0 - // when currentCount >1 - // stddev_samp = sqrt (currentMk/(currentCount -1)) - // stddev_pop = sqrt (currentMk/currentCount) - val varCol = { - if (isSample) { - currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) - } - else { - currentMk / currentCount - } - } - - If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(varCol), resultType))) - } -} - case class Sum(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil @@ -1200,3 +787,304 @@ object HyperLogLogPlusPlus { ) // scalastyle:on } + +abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { + + /** + * The maximum central moment order to be computed + */ + protected def maxMoment: Int + + override def children: Seq[Expression] = Seq(child) + + override def nullable: Boolean = false + + override def dataType: DataType = DoubleType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + def cloneBufferAttributes: Seq[Attribute] = aggBufferAttributes.map(_.newInstance()) + + /** + * The number of central moments to store in the buffer + */ + private[this] val numMoments = 5 + + override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(numMoments) { i => + AttributeReference(s"M$i", DoubleType)() + } + + /** + * Initialize all moments to zero + */ + override def initialize(buffer: MutableRow): Unit = { + var aggIndex = 0 + while (aggIndex < numMoments) { + buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) + aggIndex += 1 + } + } + + /** + * Update the central moments buffer. + */ + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val v = child.eval(input) + if (v != null) { + val updateValue = v match { + case d: java.lang.Number => d.doubleValue() + case _ => 0.0 + } + val currentM0 = buffer.getDouble(mutableAggBufferOffset) + val currentM1 = buffer.getDouble(mutableAggBufferOffset + 1) + val currentM2 = buffer.getDouble(mutableAggBufferOffset + 2) + val currentM3 = buffer.getDouble(mutableAggBufferOffset + 3) + val currentM4 = buffer.getDouble(mutableAggBufferOffset + 4) + + val updateM0 = currentM0 + 1.0 + val delta = updateValue - currentM1 + val deltaN = delta / updateM0 + + val updateM1 = currentM1 + delta / updateM0 + val updateM2 = if (maxMoment >= 2) { + currentM2 + delta * (delta - deltaN) + } else { + 0.0 + } + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val updateM3 = if (maxMoment >= 3) { + currentM3 - 3.0 * deltaN * updateM2 + delta * (delta2 - deltaN2) + } else { + 0.0 + } + val updateM4 = if (maxMoment >= 4) { + currentM4 - 4.0 * deltaN * updateM3 - 6.0 * deltaN2 * updateM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + 0.0 + } + + buffer.setDouble(mutableAggBufferOffset, updateM0) + buffer.setDouble(mutableAggBufferOffset + 1, updateM1) + buffer.setDouble(mutableAggBufferOffset + 2, updateM2) + buffer.setDouble(mutableAggBufferOffset + 3, updateM3) + buffer.setDouble(mutableAggBufferOffset + 4, updateM4) + } + } + + /** Merge two central moment buffers. */ + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val zeroMoment1 = buffer1.getDouble(mutableAggBufferOffset) + val zeroMoment2 = buffer2.getDouble(inputAggBufferOffset) + val firstMoment1 = buffer1.getDouble(mutableAggBufferOffset + 1) + val firstMoment2 = buffer2.getDouble(inputAggBufferOffset + 1) + val secondMoment1 = buffer1.getDouble(mutableAggBufferOffset + 2) + val secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) + val thirdMoment1 = buffer1.getDouble(mutableAggBufferOffset + 3) + val thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) + val fourthMoment1 = buffer1.getDouble(mutableAggBufferOffset + 4) + val fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) + + val zeroMoment = zeroMoment1 + zeroMoment2 + val delta = firstMoment2 - firstMoment1 + val deltaN = delta / zeroMoment + + val firstMoment = firstMoment1 + deltaN * zeroMoment2 + + val secondMoment = if (maxMoment >= 2) { + secondMoment1 + secondMoment2 + delta * deltaN * zeroMoment1 * zeroMoment2 + } else { + 0.0 + } + + val thirdMoment = if (maxMoment >= 3) { + thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * zeroMoment1 * zeroMoment2 * + (zeroMoment1 - zeroMoment2) + 3.0 * deltaN * + (zeroMoment1 * secondMoment2 - zeroMoment2 * secondMoment1) + } else { + 0.0 + } + + val fourthMoment = if (maxMoment >= 4) { + fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * zeroMoment1 * + zeroMoment2 * (zeroMoment1 * zeroMoment1 - zeroMoment1 * zeroMoment2 + + zeroMoment2 * zeroMoment2) + deltaN * deltaN * 6.0 * + (zeroMoment1 * zeroMoment1 * secondMoment2 + zeroMoment2 * zeroMoment2 * secondMoment1) + + 4.0 * deltaN * (zeroMoment1 * thirdMoment2 - zeroMoment2 * thirdMoment1) + } else { + 0.0 + } + + buffer1.setDouble(mutableAggBufferOffset, zeroMoment) + buffer1.setDouble(mutableAggBufferOffset + 1, firstMoment) + buffer1.setDouble(mutableAggBufferOffset + 2, secondMoment) + buffer1.setDouble(mutableAggBufferOffset + 3, thirdMoment) + buffer1.setDouble(mutableAggBufferOffset + 4, fourthMoment) + } +} + +case class Stddev(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "stddev" + + protected val maxMoment = 2 + + def eval(buffer: InternalRow): Any = { + // stddev = sqrt(M2 / (M0 - 1)) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + + if (M0 == 0.0) { + 0.0 + } else { + math.sqrt(M2 / (M0 - 1.0)) + } + } +} + +case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "stddev_samp" + + protected val maxMoment = 2 + + override def eval(buffer: InternalRow): Any = { + // stddev_samp = sqrt(M2 / (M0 - 1)) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + + if (M0 == 0.0) { + 0.0 + } else { + math.sqrt(M2 / (M0 - 1.0)) + } + } +} + +case class StddevPop(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "stddev_pop" + + val maxMoment = 2 + + override def eval(buffer: InternalRow): Any = { + // stddev_pop = sqrt(M2 / M0) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + + if (M0 == 0.0) { + 0.0 + } else { + math.sqrt(M2 / M0) + } + } +} + +case class Variance(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "variance" + + protected val maxMoment = 2 + + override def eval(buffer: InternalRow): Any = { + // stddev = M2 / (M0 - 1) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + + if (M0 == 0.0) { + 0.0 + } else { + M2 / (M0 - 1.0) + } + } +} + +case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "variance_samp" + + protected val maxMoment = 2 + + override def eval(buffer: InternalRow): Any = { + // var_samp = M2 / (M0 - 1) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + + if (M0 == 0.0) { + 0.0 + } else { + M2 / (M0 - 1.0) + } + } +} + +case class VariancePop(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "variance_pop" + + val maxMoment = 2 + + override def eval(buffer: InternalRow): Any = { + // var_pop = M2 / M0 + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + + if (M0 == 0.0) { + 0.0 + } else { + M2 / M0 + } + } +} + +case class Skewness(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "skewness" + + protected val maxMoment = 3 + + override def eval(buffer: InternalRow): Any = { + // skewness = sqrt(M0) * M3 / sqrt(M2^3) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + val M3 = buffer.getDouble(mutableAggBufferOffset + 3) + + if (M0 == 0.0 || M2 == 0.0) { + 0.0 + } else { + math.sqrt(M0) * M3 / math.sqrt(M2 * M2 * M2) + } + } +} + +case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "kurtosis" + + protected val maxMoment = 4 + + override def eval(buffer: InternalRow): Any = { + // kurtosis = M0 * M4 / M2^2 - 3.0 + // NOTE: this is the formula for excess kurtosis, which is default for R and NumPy + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + val M4 = buffer.getDouble(mutableAggBufferOffset + 4) + + if (M0 == 0.0) { + 0.0 + } else if (M2 == 0.0) { + -3.0 + } else { + M0 * M4 / (M2 * M2) - 3.0 + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 1be66b0747a60..59f33ad760583 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -48,18 +48,6 @@ object Utils { mode = aggregate.Complete, isDistinct = false) - case expressions.Skewness(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Skewness(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Kurtosis(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Kurtosis(child), - mode = aggregate.Complete, - isDistinct = false) - case expressions.Count(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Count(child), @@ -79,6 +67,12 @@ object Utils { mode = aggregate.Complete, isDistinct = false) + case expressions.Kurtosis(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Kurtosis(child), + mode = aggregate.Complete, + isDistinct = false) + case expressions.Last(child, ignoreNulls) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Last(child, ignoreNulls), @@ -97,6 +91,12 @@ object Utils { mode = aggregate.Complete, isDistinct = false) + case expressions.Skewness(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Skewness(child), + mode = aggregate.Complete, + isDistinct = false) + case expressions.Stddev(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Stddev(child), @@ -127,11 +127,31 @@ object Utils { mode = aggregate.Complete, isDistinct = true) - case expressions.ApproxCountDistinct(child, rsd) => + case expressions.ApproxCountDistinct(child, rsd) => { aggregate.AggregateExpression2( aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd), mode = aggregate.Complete, isDistinct = false) + } + + case expressions.Variance(child) => { + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Variance(child), + mode = aggregate.Complete, + isDistinct = false) + } + + case expressions.VariancePop(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.VariancePop(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.VarianceSamp(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.VarianceSamp(child), + mode = aggregate.Complete, + isDistinct = false) } // Check if there is any expressions.AggregateExpression1 left. // If so, we cannot convert this plan. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index aea88fed11c9a..e5570c707bf96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -1023,7 +1023,8 @@ abstract class CentralMomentAgg1(child: Expression) val partialSum = Alias(Sum(child), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")() - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) +// val castedSum = Cast(Sum(partialSum.toAttribute), dataType) + val castedSum = Cast(Literal(0.0), dataType) val castedCount = Cast(Sum(partialCount.toAttribute), dataType) SplitEvaluation( Divide(castedSum, castedCount), @@ -1052,3 +1053,50 @@ case class Skewness(child: Expression) extends CentralMomentAgg1(child) { override def toString: String = s"SKEWNESS($child)" } + +// placeholder +case class MySkewness(child: Expression) extends CentralMomentAgg1(child) { + + override def prettyName: String = "skewness" + + override def toString: String = s"SKEWNESS($child)" +} + +// placeholder +//case class Variance(child: Expression) extends CentralMomentAgg1(child) { +// +// override def prettyName: String = "variance" +// +// override def toString: String = s"VARIANCE($child)" +//} + +// Compute the sample standard deviation of a column +case class Variance(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"VARIANCE($child)" + override def isSample: Boolean = true +} + +// placeholder +case class VariancePop(child: Expression) extends CentralMomentAgg1(child) { + + override def prettyName: String = "variance_pop" + + override def toString: String = s"VAR_POP($child)" +} + +// placeholder +case class VarianceSamp(child: Expression) extends CentralMomentAgg1(child) { + + override def prettyName: String = "variance_samp" + + override def toString: String = s"VAR_SAMP($child)" +} + +// placeholder +case class MyKurtosis(child: Expression) extends CentralMomentAgg1(child) { + + override def prettyName: String = "kurtosis" + + override def toString: String = s"KURTOSIS($child)" +} \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 017efd2a166a7..ddfcba7cfb0d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -167,7 +167,7 @@ abstract class MutableRow extends InternalRow { // default implementation (slow) def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } - def setByte(i: Int, value: Byte): Unit = { update(i, value) } + def setByte(i: Int, value: Byte): Unit = { (i, value) } def setShort(i: Int, value: Short): Unit = { update(i, value) } def setInt(i: Int, value: Int): Unit = { update(i, value) } def setLong(i: Int, value: Long): Unit = { update(i, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index a524314532bad..5f1b299c0cdb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -127,6 +127,9 @@ class GroupedData protected[sql]( case "stddev" => Stddev case "stddev_pop" => StddevPop case "stddev_samp" => StddevSamp + case "variance" => Variance + case "var_pop" => VariancePop + case "var_samp" => VarianceSamp case "sum" => Sum case "skewness" => Skewness case "kurtosis" => Kurtosis @@ -253,7 +256,7 @@ class GroupedData protected[sql]( } /** - * Compute the skewness for each numeric columns for each group. This is an alias for `skewness`. + * Compute the skewness for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. * When specified columns are given, only compute the skewness values for them. * @@ -265,7 +268,6 @@ class GroupedData protected[sql]( } /** - * Compute the kurtosis for each numeric columns for each group. This is an alias for `kurtosis`. * The resulting [[DataFrame]] will also contain the grouping columns. * When specified columns are given, only compute the kurtosis values for them. * @@ -359,4 +361,40 @@ class GroupedData protected[sql]( def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Sum) } + + /** + * Compute the sample variance for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the variance for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def variance(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Variance) + } + + /** + * Compute the population variance for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the variance for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def var_pop(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(VariancePop) + } + + /** + * Compute the sample variance for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the variance for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def var_samp(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(VarianceSamp) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a052745a1de34..08ad17b9b5b67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -228,6 +228,22 @@ object functions { */ def first(columnName: String): Column = first(Column(columnName)) + /** + * Aggregate function: returns the kurtosis of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(e: Column): Column = Kurtosis(e.expr) + + /** + * Aggregate function: returns the kurtosis of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) + /** * Aggregate function: returns the last value in a group. * @@ -279,65 +295,72 @@ object functions { def mean(columnName: String): Column = avg(columnName) /** - * Aggregate function: returns the skewness of the values in a group. - * Alias for skewness. + * Aggregate function: returns the minimum value of the expression in a group. * * @group agg_funcs - * @since 1.6.0 + * @since 1.3.0 */ - def skewness(e: Column): Column = Skewness(e.expr) + def min(e: Column): Column = Min(e.expr) + + /** + * Aggregate function: returns the minimum value of the column in a group. + * + * @group agg_funcs + * @since 1.3.0 + */ + def min(columnName: String): Column = min(Column(columnName)) /** * Aggregate function: returns the skewness of the values in a group. - * Alias for skewness. * * @group agg_funcs * @since 1.6.0 */ - def skewness(columnName: String): Column = skewness(Column(columnName)) + def skewness(e: Column): Column = Skewness(e.expr) /** - * Aggregate function: returns the kurtosis of the values in a group. - * Alias for kurtosis. + * Aggregate function: returns the skewness of the values in a group. * * @group agg_funcs * @since 1.6.0 */ - def kurtosis(e: Column): Column = Kurtosis(e.expr) + def skewness(columnName: String): Column = skewness(Column(columnName)) /** - * Aggregate function: returns the kurtosis of the values in a group. - * Alias for kurtosis. + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. * * @group agg_funcs * @since 1.6.0 */ - def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) + def stddev(e: Column): Column = Stddev(e.expr) /** - * Aggregate function: returns the minimum value of the expression in a group. + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. * * @group agg_funcs - * @since 1.3.0 + * @since 1.6.0 */ - def min(e: Column): Column = Min(e.expr) + def stddev(columnName: String): Column = stddev(Column(columnName)) /** - * Aggregate function: returns the minimum value of the column in a group. + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. * * @group agg_funcs - * @since 1.3.0 + * @since 1.6.0 */ - def min(columnName: String): Column = min(Column(columnName)) + def stddev_samp(e: Column): Column = StddevSamp(e.expr) /** - * Aggregate function: returns the unbiased sample standard deviation - * of the expression in a group. + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. * * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = Stddev(e.expr) + def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) /** * Aggregate function: returns the population standard deviation of @@ -349,13 +372,13 @@ object functions { def stddev_pop(e: Column): Column = StddevPop(e.expr) /** - * Aggregate function: returns the unbiased sample standard deviation of + * Aggregate function: returns the population standard deviation of * the expression in a group. * * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(e: Column): Column = StddevSamp(e.expr) + def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName)) /** * Aggregate function: returns the sum of all values in the expression. @@ -389,6 +412,54 @@ object functions { */ def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(e: Column): Column = Variance(e.expr) + + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(columnName: String): Column = variance(Column(columnName)) + + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(e: Column): Column = VarianceSamp(e.expr) + + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(columnName: String): Column = var_samp(Column(columnName)) + + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(e: Column): Column = VariancePop(e.expr) + + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(columnName: String): Column = var_pop(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c5b05fcef9e0f..def4dd6c45e3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -329,27 +329,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) - // STDDEV - testCodeGen( - "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", - (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) - testCodeGen( - "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", - Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) - // SKEWNESS - testCodeGen( - "SELECT a, skewness(b) FROM testData2 GROUP BY a", - (1 to 3).map(i => Row(i, 0.0))) - testCodeGen( - "SELECT skewness(b) FROM testData2", - Row(0.0) :: Nil) - // KURTOSIS - testCodeGen( - "SELECT a, kurtosis(b) FROM testData2 GROUP BY a", - (1 to 3).map(i => Row(i, -2.0))) - testCodeGen( - "SELECT kurtosis(b) FROM testData2", - Row(-2.0) :: Nil) +// // STDDEV +// testCodeGen( +// "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", +// (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) +// // SKEWNESS +// testCodeGen( +// "SELECT a, skewness(b) FROM testData2 GROUP BY a", +// (1 to 3).map(i => Row(i, 0.0))) +// testCodeGen( +// "SELECT skewness(b) FROM testData2", +// Row(0.0) :: Nil) +// // KURTOSIS +// testCodeGen( +// "SELECT a, kurtosis(b) FROM testData2 GROUP BY a", +// (1 to 3).map(i => Row(i, -2.0))) +// testCodeGen( +// "SELECT kurtosis(b) FROM testData2", +// Row(-2.0) :: Nil) // Some combinations. testCodeGen( """ @@ -370,9 +367,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), stddev('a'), skewness('a')," + - "kurtosis('a'), count(null) FROM testData", - Row(null, null, null, null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -751,6 +747,27 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("var_samp") { + checkAnswer( + sql("SELECT VAR_SAMP(a) FROM testData2"), + Row(4/5.0) + ) + } + + test("variance") { + checkAnswer( + sql("SELECT VARIANCE(a) FROM testData2"), + Row(4/5.0) + ) + } + + test("var_pop") { + checkAnswer( + sql("SELECT VAR_POP(a) FROM testData2"), + Row(4/6.0) + ) + } + test("skewness") { checkAnswer( sql("SELECT skewness(a) FROM testData2"), From 83fb6825e7412df373a6a71921dc7c539d7efab3 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 19 Oct 2015 08:59:11 -0700 Subject: [PATCH 08/22] cleaning up and style fixes --- .../expressions/aggregate/utils.scala | 6 ++--- .../sql/catalyst/expressions/aggregates.scala | 27 ++----------------- .../sql/catalyst/expressions/arithmetic.scala | 1 - .../spark/sql/catalyst/expressions/rows.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 18 ------------- 5 files changed, 5 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 59f33ad760583..c911ec53f1ba0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -127,19 +127,17 @@ object Utils { mode = aggregate.Complete, isDistinct = true) - case expressions.ApproxCountDistinct(child, rsd) => { + case expressions.ApproxCountDistinct(child, rsd) => aggregate.AggregateExpression2( aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd), mode = aggregate.Complete, isDistinct = false) - } - case expressions.Variance(child) => { + case expressions.Variance(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Variance(child), mode = aggregate.Complete, isDistinct = false) - } case expressions.VariancePop(child) => aggregate.AggregateExpression2( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index e5570c707bf96..554746f0e31e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -1055,26 +1055,11 @@ case class Skewness(child: Expression) extends CentralMomentAgg1(child) { } // placeholder -case class MySkewness(child: Expression) extends CentralMomentAgg1(child) { +case class Variance(child: Expression) extends CentralMomentAgg1(child) { - override def prettyName: String = "skewness" - - override def toString: String = s"SKEWNESS($child)" -} - -// placeholder -//case class Variance(child: Expression) extends CentralMomentAgg1(child) { -// -// override def prettyName: String = "variance" -// -// override def toString: String = s"VARIANCE($child)" -//} - -// Compute the sample standard deviation of a column -case class Variance(child: Expression) extends StddevAgg1(child) { + override def prettyName: String = "variance" override def toString: String = s"VARIANCE($child)" - override def isSample: Boolean = true } // placeholder @@ -1091,12 +1076,4 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg1(child) { override def prettyName: String = "variance_samp" override def toString: String = s"VAR_SAMP($child)" -} - -// placeholder -case class MyKurtosis(child: Expression) extends CentralMomentAgg1(child) { - - override def prettyName: String = "kurtosis" - - override def toString: String = s"KURTOSIS($child)" } \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9031da6fcf1af..61a17fd7db0fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -93,7 +93,6 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } - abstract class BinaryArithmetic extends BinaryOperator { override def dataType: DataType = left.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index ddfcba7cfb0d0..017efd2a166a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -167,7 +167,7 @@ abstract class MutableRow extends InternalRow { // default implementation (slow) def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } - def setByte(i: Int, value: Byte): Unit = { (i, value) } + def setByte(i: Int, value: Byte): Unit = { update(i, value) } def setShort(i: Int, value: Short): Unit = { update(i, value) } def setInt(i: Int, value: Int): Unit = { update(i, value) } def setLong(i: Int, value: Long): Unit = { update(i, value) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index def4dd6c45e3e..cad679950c5d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -329,24 +329,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) -// // STDDEV -// testCodeGen( -// "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", -// (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) -// // SKEWNESS -// testCodeGen( -// "SELECT a, skewness(b) FROM testData2 GROUP BY a", -// (1 to 3).map(i => Row(i, 0.0))) -// testCodeGen( -// "SELECT skewness(b) FROM testData2", -// Row(0.0) :: Nil) -// // KURTOSIS -// testCodeGen( -// "SELECT a, kurtosis(b) FROM testData2 GROUP BY a", -// (1 to 3).map(i => Row(i, -2.0))) -// testCodeGen( -// "SELECT kurtosis(b) FROM testData2", -// Row(-2.0) :: Nil) // Some combinations. testCodeGen( """ From d54fb0d2f6cfba1e419076eaf8495c3389c9d2f3 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 19 Oct 2015 09:57:46 -0700 Subject: [PATCH 09/22] cast child expression to double type --- .../sql/catalyst/expressions/aggregate/functions.scala | 8 +++----- .../src/main/scala/org/apache/spark/sql/GroupedData.scala | 1 + 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 6296230c723ae..e6abe1f34bcba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -811,8 +811,6 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - def cloneBufferAttributes: Seq[Attribute] = aggBufferAttributes.map(_.newInstance()) - /** * The number of central moments to store in the buffer */ @@ -837,10 +835,10 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * Update the central moments buffer. */ override def update(buffer: MutableRow, input: InternalRow): Unit = { - val v = child.eval(input) + val v = Cast(child, DoubleType).eval(input) if (v != null) { val updateValue = v match { - case d: java.lang.Number => d.doubleValue() + case d: Double => d case _ => 0.0 } val currentM0 = buffer.getDouble(mutableAggBufferOffset) @@ -1031,7 +1029,7 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) { override def prettyName: String = "variance_pop" - val maxMoment = 2 + protected val maxMoment = 2 override def eval(buffer: InternalRow): Any = { // var_pop = M2 / M0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 5f1b299c0cdb1..dc96384a4d28d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -268,6 +268,7 @@ class GroupedData protected[sql]( } /** + * Compute the kurtosis for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. * When specified columns are given, only compute the kurtosis values for them. * From 853922ad44cac443bb75c2c8bf3eb55fe3987a68 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 20 Oct 2015 09:03:35 -0700 Subject: [PATCH 10/22] using vars for aggregator --- .../expressions/aggregate/functions.scala | 584 ++++++++++++------ .../sql/catalyst/expressions/aggregates.scala | 2 +- 2 files changed, 398 insertions(+), 188 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index e6abe1f34bcba..18a241fecb64c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -327,6 +327,149 @@ case class Min(child: Expression) extends DeclarativeAggregate { override val evaluateExpression = min } +// Compute the sample standard deviation of a column +case class Stddev(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev" +} + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = false + override def prettyName: String = "stddev_pop" +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev_samp" +} + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + def isSample: Boolean + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select stddev(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = DoubleType + + private val preCount = AttributeReference("preCount", resultType)() + private val currentCount = AttributeReference("currentCount", resultType)() + private val preAvg = AttributeReference("preAvg", resultType)() + private val currentAvg = AttributeReference("currentAvg", resultType)() + private val currentMk = AttributeReference("currentMk", resultType)() + + override val aggBufferAttributes = preCount :: currentCount :: preAvg :: + currentAvg :: currentMk :: Nil + + override val initialValues = Seq( + /* preCount = */ Cast(Literal(0), resultType), + /* currentCount = */ Cast(Literal(0), resultType), + /* preAvg = */ Cast(Literal(0), resultType), + /* currentAvg = */ Cast(Literal(0), resultType), + /* currentMk = */ Cast(Literal(0), resultType) + ) + + override val updateExpressions = { + + // update average + // avg = avg + (value - avg)/count + def avgAdd: Expression = { + currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount) + } + + // update sum of square of difference from mean + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + def mkAdd: Expression = { + val delta1 = Cast(child, resultType) - preAvg + val delta2 = Cast(child, resultType) - currentAvg + currentMk + (delta1 * delta2) + } + + Seq( + /* preCount = */ If(IsNull(child), preCount, currentCount), + /* currentCount = */ If(IsNull(child), currentCount, + Add(currentCount, Cast(Literal(1), resultType))), + /* preAvg = */ If(IsNull(child), preAvg, currentAvg), + /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), + /* currentMk = */ If(IsNull(child), currentMk, mkAdd) + ) + } + + override val mergeExpressions = { + + // count merge + def countMerge: Expression = { + currentCount.left + currentCount.right + } + + // average merge + def avgMerge: Expression = { + ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / + (preCount + currentCount.right) + } + + // update sum of square differences + def mkMerge: Expression = { + val avgDelta = currentAvg.right - preAvg + val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / + (preCount + currentCount.right) + + currentMk.left + currentMk.right + mkDelta + } + + Seq( + /* preCount = */ If(IsNull(currentCount.left), + Cast(Literal(0), resultType), currentCount.left), + /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, + If(IsNull(currentCount.right), currentCount.left, countMerge)), + /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), + /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, + If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), + /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, + If(IsNull(currentMk.right), currentMk.left, mkMerge)) + ) + } + + override val evaluateExpression = { + // when currentCount == 0, return null + // when currentCount == 1, return 0 + // when currentCount >1 + // stddev_samp = sqrt (currentMk/(currentCount -1)) + // stddev_pop = sqrt (currentMk/currentCount) + val varCol = { + if (isSample) { + currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) + } + else { + currentMk / currentCount + } + } + + If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(varCol), resultType))) + } +} + case class Sum(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil @@ -788,6 +931,23 @@ object HyperLogLogPlusPlus { // scalastyle:on } +/** + * A central moment is the expected value of a specified power of the deviation of a random + * variable from the mean. Central moments are often used to characterize the properties of about + * the shape of a distribution. + * + * This class implements online, one-pass algorithms for computing the central moments of a set of + * points. + * + * References: + * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." + * 2015. http://arxiv.org/abs/1510.04923 + * + * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + * Algorithms for calculating variance (Wikipedia)]] + * + * @param child to compute central moments of. + */ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { /** @@ -820,6 +980,11 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w AttributeReference(s"M$i", DoubleType)() } + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + /** * Initialize all moments to zero */ @@ -831,6 +996,11 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w } } + private[this] var delta = 0.0 + private[this] var deltaN = 0.0 + private[this] var delta2 = 0.0 + private[this] var deltaN2 = 0.0 + /** * Update the central moments buffer. */ @@ -841,248 +1011,288 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w case d: Double => d case _ => 0.0 } - val currentM0 = buffer.getDouble(mutableAggBufferOffset) - val currentM1 = buffer.getDouble(mutableAggBufferOffset + 1) - val currentM2 = buffer.getDouble(mutableAggBufferOffset + 2) - val currentM3 = buffer.getDouble(mutableAggBufferOffset + 3) - val currentM4 = buffer.getDouble(mutableAggBufferOffset + 4) - - val updateM0 = currentM0 + 1.0 - val delta = updateValue - currentM1 - val deltaN = delta / updateM0 - - val updateM1 = currentM1 + delta / updateM0 - val updateM2 = if (maxMoment >= 2) { - currentM2 + delta * (delta - deltaN) - } else { - 0.0 + var n = buffer.getDouble(mutableAggBufferOffset) + var mean = buffer.getDouble(mutableAggBufferOffset + 1) + var M2 = 0.0 + var M3 = 0.0 + var M4 = 0.0 + + n += 1.0 + delta = updateValue - mean + deltaN = delta / n + mean += deltaN + buffer.setDouble(mutableAggBufferOffset, n) + buffer.setDouble(mutableAggBufferOffset + 1, mean) + + if (maxMoment >= 2) { + M2 = buffer.getDouble(mutableAggBufferOffset + 2) + M2 += delta * (delta - deltaN) + buffer.setDouble(mutableAggBufferOffset + 2, M2) } - val delta2 = delta * delta - val deltaN2 = deltaN * deltaN - val updateM3 = if (maxMoment >= 3) { - currentM3 - 3.0 * deltaN * updateM2 + delta * (delta2 - deltaN2) - } else { - 0.0 + + if (maxMoment >= 3) { + delta2 = delta * delta + deltaN2 = deltaN * deltaN + M3 = buffer.getDouble(mutableAggBufferOffset + 3) + println(M3) + M3 += -3.0 * deltaN * M2 + delta * (delta2 - deltaN2) + println(M3) + buffer.setDouble(mutableAggBufferOffset + 3, M3) } - val updateM4 = if (maxMoment >= 4) { - currentM4 - 4.0 * deltaN * updateM3 - 6.0 * deltaN2 * updateM2 + + + if (maxMoment >= 4) { + M4 = buffer.getDouble(mutableAggBufferOffset + 4) + M4 += -4.0 * deltaN * M3 - 6.0 * deltaN2 * M2 + delta * (delta * delta2 - deltaN * deltaN2) - } else { - 0.0 + buffer.setDouble(mutableAggBufferOffset + 4, M4) } - - buffer.setDouble(mutableAggBufferOffset, updateM0) - buffer.setDouble(mutableAggBufferOffset + 1, updateM1) - buffer.setDouble(mutableAggBufferOffset + 2, updateM2) - buffer.setDouble(mutableAggBufferOffset + 3, updateM3) - buffer.setDouble(mutableAggBufferOffset + 4, updateM4) } } /** Merge two central moment buffers. */ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val zeroMoment1 = buffer1.getDouble(mutableAggBufferOffset) - val zeroMoment2 = buffer2.getDouble(inputAggBufferOffset) - val firstMoment1 = buffer1.getDouble(mutableAggBufferOffset + 1) - val firstMoment2 = buffer2.getDouble(inputAggBufferOffset + 1) - val secondMoment1 = buffer1.getDouble(mutableAggBufferOffset + 2) - val secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) - val thirdMoment1 = buffer1.getDouble(mutableAggBufferOffset + 3) - val thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) - val fourthMoment1 = buffer1.getDouble(mutableAggBufferOffset + 4) - val fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) - - val zeroMoment = zeroMoment1 + zeroMoment2 - val delta = firstMoment2 - firstMoment1 - val deltaN = delta / zeroMoment - - val firstMoment = firstMoment1 + deltaN * zeroMoment2 - - val secondMoment = if (maxMoment >= 2) { - secondMoment1 + secondMoment2 + delta * deltaN * zeroMoment1 * zeroMoment2 - } else { - 0.0 + val n1 = buffer1.getDouble(mutableAggBufferOffset) + val n2 = buffer2.getDouble(inputAggBufferOffset) + val mean1 = buffer1.getDouble(mutableAggBufferOffset + 1) + val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) + + var secondMoment1 = 0.0 + var secondMoment2 = 0.0 + var secondMoment = 0.0 + + var thirdMoment1 = 0.0 + var thirdMoment2 = 0.0 + var thirdMoment = 0.0 + + var fourthMoment1 = 0.0 + var fourthMoment2 = 0.0 + var fourthMoment = 0.0 + + val n = n1 + n2 + delta = mean2 - mean1 + deltaN = delta / n + val mean = mean1 + deltaN * n2 + + buffer1.setDouble(mutableAggBufferOffset, n) + buffer1.setDouble(mutableAggBufferOffset + 1, mean) + + if (maxMoment >= 2) { + secondMoment1 = buffer1.getDouble(mutableAggBufferOffset + 2) + secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) + secondMoment = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 + buffer1.setDouble(mutableAggBufferOffset + 2, secondMoment) } - val thirdMoment = if (maxMoment >= 3) { - thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * zeroMoment1 * zeroMoment2 * - (zeroMoment1 - zeroMoment2) + 3.0 * deltaN * - (zeroMoment1 * secondMoment2 - zeroMoment2 * secondMoment1) - } else { - 0.0 - } - val fourthMoment = if (maxMoment >= 4) { - fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * zeroMoment1 * - zeroMoment2 * (zeroMoment1 * zeroMoment1 - zeroMoment1 * zeroMoment2 + - zeroMoment2 * zeroMoment2) + deltaN * deltaN * 6.0 * - (zeroMoment1 * zeroMoment1 * secondMoment2 + zeroMoment2 * zeroMoment2 * secondMoment1) + - 4.0 * deltaN * (zeroMoment1 * thirdMoment2 - zeroMoment2 * thirdMoment1) - } else { - 0.0 + if (maxMoment >= 3) { + thirdMoment1 = buffer1.getDouble(mutableAggBufferOffset + 3) + thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) + thirdMoment = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * + (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) + buffer1.setDouble(mutableAggBufferOffset + 3, thirdMoment) } - buffer1.setDouble(mutableAggBufferOffset, zeroMoment) - buffer1.setDouble(mutableAggBufferOffset + 1, firstMoment) - buffer1.setDouble(mutableAggBufferOffset + 2, secondMoment) - buffer1.setDouble(mutableAggBufferOffset + 3, thirdMoment) - buffer1.setDouble(mutableAggBufferOffset + 4, fourthMoment) - } -} - -case class Stddev(child: Expression) extends CentralMomentAgg(child) { - - override def prettyName: String = "stddev" - - protected val maxMoment = 2 - - def eval(buffer: InternalRow): Any = { - // stddev = sqrt(M2 / (M0 - 1)) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - - if (M0 == 0.0) { - 0.0 - } else { - math.sqrt(M2 / (M0 - 1.0)) + if (maxMoment >= 4) { + fourthMoment1 = buffer1.getDouble(mutableAggBufferOffset + 4) + fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) + fourthMoment = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * + n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * + (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + + 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) + buffer1.setDouble(mutableAggBufferOffset + 4, fourthMoment) } } -} - -case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { - - override def prettyName: String = "stddev_samp" - protected val maxMoment = 2 + def eval(buffer: InternalRow): Any = this match { + case _: VariancePop => + // stddev = M2 / (M0 - 1) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - override def eval(buffer: InternalRow): Any = { - // stddev_samp = sqrt(M2 / (M0 - 1)) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + if (M0 == 0.0) { + Double.NaN + } else { + M2 / M0 + } + case _: VarianceSamp => + // stddev = M2 / (M0 - 1) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - if (M0 == 0.0) { - 0.0 - } else { - math.sqrt(M2 / (M0 - 1.0)) - } + if (M0 == 0.0) { + Double.NaN + } else { + M2 / (M0 - 1.0) + } + case _: Variance => + // stddev = M2 / (M0 - 1) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M1 = buffer.getDouble(mutableAggBufferOffset + 1) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + + if (M0 == 0.0) { + Double.NaN + } else { + M2 / (M0 - 1.0) + } + case _: Skewness => + // skewness = sqrt(M0) * M3 / sqrt(M2^3) + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + val M3 = buffer.getDouble(mutableAggBufferOffset + 3) + + if (M0 == 0.0 || M2 == 0.0) { + Double.NaN + } else { + math.sqrt(M0) * M3 / math.sqrt(M2 * M2 * M2) + } + case _: Kurtosis => + // kurtosis = M0 * M4 / M2^2 - 3.0 + // NOTE: this is the formula for excess kurtosis, which is default for R and NumPy + val M0 = buffer.getDouble(mutableAggBufferOffset) + val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + val M4 = buffer.getDouble(mutableAggBufferOffset + 4) + + if (M0 == 0.0 || M2 == 0.0) { + Double.NaN + } else { + M0 * M4 / (M2 * M2) - 3.0 + } + case _ => 0.0 } } -case class StddevPop(child: Expression) extends CentralMomentAgg(child) { - - override def prettyName: String = "stddev_pop" +case class Variance(child: Expression, mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { - val maxMoment = 2 - - override def eval(buffer: InternalRow): Any = { - // stddev_pop = sqrt(M2 / M0) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - - if (M0 == 0.0) { - 0.0 - } else { - math.sqrt(M2 / M0) - } - } -} + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) -case class Variance(child: Expression) extends CentralMomentAgg(child) { + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) override def prettyName: String = "variance" protected val maxMoment = 2 - override def eval(buffer: InternalRow): Any = { - // stddev = M2 / (M0 - 1) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - - if (M0 == 0.0) { - 0.0 - } else { - M2 / (M0 - 1.0) - } - } +// override def eval(buffer: InternalRow): Any = { +// // stddev = M2 / (M0 - 1) +// val M0 = buffer.getDouble(mutableAggBufferOffset) +// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) +// +// if (M0 == 0.0) { +// 0.0 +// } else { +// M2 / (M0 - 1.0) +// } +// } } -case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { +case class VarianceSamp(child: Expression, mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) override def prettyName: String = "variance_samp" protected val maxMoment = 2 - override def eval(buffer: InternalRow): Any = { - // var_samp = M2 / (M0 - 1) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - - if (M0 == 0.0) { - 0.0 - } else { - M2 / (M0 - 1.0) - } - } +// override def eval(buffer: InternalRow): Any = { +// // var_samp = M2 / (M0 - 1) +// val M0 = buffer.getDouble(mutableAggBufferOffset) +// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) +// +// if (M0 == 0.0) { +// 0.0 +// } else { +// M2 / (M0 - 1.0) +// } +// } } -case class VariancePop(child: Expression) extends CentralMomentAgg(child) { +case class VariancePop(child: Expression, mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) override def prettyName: String = "variance_pop" protected val maxMoment = 2 - override def eval(buffer: InternalRow): Any = { - // var_pop = M2 / M0 - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - - if (M0 == 0.0) { - 0.0 - } else { - M2 / M0 - } - } +// override def eval(buffer: InternalRow): Any = { +// // var_pop = M2 / M0 +// val M0 = buffer.getDouble(mutableAggBufferOffset) +// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) +// +// if (M0 == 0.0) { +// 0.0 +// } else { +// M2 / M0 +// } +// } } -case class Skewness(child: Expression) extends CentralMomentAgg(child) { +case class Skewness(child: Expression, mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) override def prettyName: String = "skewness" protected val maxMoment = 3 - override def eval(buffer: InternalRow): Any = { - // skewness = sqrt(M0) * M3 / sqrt(M2^3) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - val M3 = buffer.getDouble(mutableAggBufferOffset + 3) - - if (M0 == 0.0 || M2 == 0.0) { - 0.0 - } else { - math.sqrt(M0) * M3 / math.sqrt(M2 * M2 * M2) - } - } +// override def eval(buffer: InternalRow): Any = { +// // skewness = sqrt(M0) * M3 / sqrt(M2^3) +// val M0 = buffer.getDouble(mutableAggBufferOffset) +// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) +// val M3 = buffer.getDouble(mutableAggBufferOffset + 3) +// +// if (M0 == 0.0 || M2 == 0.0) { +// 0.0 +// } else { +// math.sqrt(M0) * M3 / math.sqrt(M2 * M2 * M2) +// } +// } } -case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { +case class Kurtosis(child: Expression, mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) override def prettyName: String = "kurtosis" protected val maxMoment = 4 - override def eval(buffer: InternalRow): Any = { - // kurtosis = M0 * M4 / M2^2 - 3.0 - // NOTE: this is the formula for excess kurtosis, which is default for R and NumPy - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - val M4 = buffer.getDouble(mutableAggBufferOffset + 4) - - if (M0 == 0.0) { - 0.0 - } else if (M2 == 0.0) { - -3.0 - } else { - M0 * M4 / (M2 * M2) - 3.0 - } - } +// override def eval(buffer: InternalRow): Any = { +// // kurtosis = M0 * M4 / M2^2 - 3.0 +// // NOTE: this is the formula for excess kurtosis, which is default for R and NumPy +// val M0 = buffer.getDouble(mutableAggBufferOffset) +// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) +// val M4 = buffer.getDouble(mutableAggBufferOffset + 4) +// +// if (M0 == 0.0) { +// 0.0 +// } else if (M2 == 0.0) { +// -3.0 +// } else { +// M0 * M4 / (M2 * M2) - 3.0 +// } +// } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 554746f0e31e4..4a898d7599e6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -1076,4 +1076,4 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg1(child) { override def prettyName: String = "variance_samp" override def toString: String = s"VAR_SAMP($child)" -} \ No newline at end of file +} From 4a5350eb35742745266821b0c0c96a6f08bb3d88 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 20 Oct 2015 18:58:13 -0700 Subject: [PATCH 11/22] restructuring eval method --- .../expressions/aggregate/functions.scala | 257 ++++++++---------- 1 file changed, 107 insertions(+), 150 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 18a241fecb64c..f6f16e86974c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -943,7 +943,7 @@ object HyperLogLogPlusPlus { * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." * 2015. http://arxiv.org/abs/1510.04923 * - * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance * Algorithms for calculating variance (Wikipedia)]] * * @param child to compute central moments of. @@ -953,7 +953,9 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w /** * The maximum central moment order to be computed */ - protected def maxMoment: Int + protected def momentOrder: Int + + protected def sufficientMoments: Array[Int] override def children: Seq[Expression] = Seq(child) @@ -996,6 +998,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w } } + // frequently used values for online updates private[this] var delta = 0.0 private[this] var deltaN = 0.0 private[this] var delta2 = 0.0 @@ -1013,9 +1016,9 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w } var n = buffer.getDouble(mutableAggBufferOffset) var mean = buffer.getDouble(mutableAggBufferOffset + 1) - var M2 = 0.0 - var M3 = 0.0 - var M4 = 0.0 + var m2 = 0.0 + var m3 = 0.0 + var m4 = 0.0 n += 1.0 delta = updateValue - mean @@ -1024,32 +1027,32 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w buffer.setDouble(mutableAggBufferOffset, n) buffer.setDouble(mutableAggBufferOffset + 1, mean) - if (maxMoment >= 2) { - M2 = buffer.getDouble(mutableAggBufferOffset + 2) - M2 += delta * (delta - deltaN) - buffer.setDouble(mutableAggBufferOffset + 2, M2) + if (momentOrder >= 2) { + m2 = buffer.getDouble(mutableAggBufferOffset + 2) + m2 += delta * (delta - deltaN) + buffer.setDouble(mutableAggBufferOffset + 2, m2) } - if (maxMoment >= 3) { + if (momentOrder >= 3) { delta2 = delta * delta deltaN2 = deltaN * deltaN - M3 = buffer.getDouble(mutableAggBufferOffset + 3) - println(M3) - M3 += -3.0 * deltaN * M2 + delta * (delta2 - deltaN2) - println(M3) - buffer.setDouble(mutableAggBufferOffset + 3, M3) + m3 = buffer.getDouble(mutableAggBufferOffset + 3) + m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) + buffer.setDouble(mutableAggBufferOffset + 3, m3) } - if (maxMoment >= 4) { - M4 = buffer.getDouble(mutableAggBufferOffset + 4) - M4 += -4.0 * deltaN * M3 - 6.0 * deltaN2 * M2 + + if (momentOrder >= 4) { + m4 = buffer.getDouble(mutableAggBufferOffset + 4) + m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + delta * (delta * delta2 - deltaN * deltaN2) - buffer.setDouble(mutableAggBufferOffset + 4, M4) + buffer.setDouble(mutableAggBufferOffset + 4, m4) } } } - /** Merge two central moment buffers. */ + /** + * Merge two central moment buffers. + */ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { val n1 = buffer1.getDouble(mutableAggBufferOffset) val n2 = buffer2.getDouble(inputAggBufferOffset) @@ -1076,7 +1079,9 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w buffer1.setDouble(mutableAggBufferOffset, n) buffer1.setDouble(mutableAggBufferOffset + 1, mean) - if (maxMoment >= 2) { + // higher order moments computed according to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + if (momentOrder >= 2) { secondMoment1 = buffer1.getDouble(mutableAggBufferOffset + 2) secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) secondMoment = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 @@ -1084,7 +1089,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w } - if (maxMoment >= 3) { + if (momentOrder >= 3) { thirdMoment1 = buffer1.getDouble(mutableAggBufferOffset + 3) thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) thirdMoment = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * @@ -1092,7 +1097,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w buffer1.setDouble(mutableAggBufferOffset + 3, thirdMoment) } - if (maxMoment >= 4) { + if (momentOrder >= 4) { fourthMoment1 = buffer1.getDouble(mutableAggBufferOffset + 4) fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) fourthMoment = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * @@ -1103,67 +1108,52 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w } } - def eval(buffer: InternalRow): Any = this match { - case _: VariancePop => - // stddev = M2 / (M0 - 1) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + /** + * Compute aggregate statistic from sufficient moments. + */ + def getStatistic(n: Double, moments: Array[Double]): Double - if (M0 == 0.0) { - Double.NaN - } else { - M2 / M0 - } - case _: VarianceSamp => - // stddev = M2 / (M0 - 1) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) + override final def eval(buffer: InternalRow): Any = { + val n = buffer.getDouble(mutableAggBufferOffset) + val moments = sufficientMoments.map { momentIdx => + buffer.getDouble(mutableAggBufferOffset + momentIdx) + } + getStatistic(n, moments) + } +} - if (M0 == 0.0) { - Double.NaN - } else { - M2 / (M0 - 1.0) - } - case _: Variance => - // stddev = M2 / (M0 - 1) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M1 = buffer.getDouble(mutableAggBufferOffset + 1) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - - if (M0 == 0.0) { - Double.NaN - } else { - M2 / (M0 - 1.0) - } - case _: Skewness => - // skewness = sqrt(M0) * M3 / sqrt(M2^3) - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - val M3 = buffer.getDouble(mutableAggBufferOffset + 3) - - if (M0 == 0.0 || M2 == 0.0) { - Double.NaN - } else { - math.sqrt(M0) * M3 / math.sqrt(M2 * M2 * M2) - } - case _: Kurtosis => - // kurtosis = M0 * M4 / M2^2 - 3.0 - // NOTE: this is the formula for excess kurtosis, which is default for R and NumPy - val M0 = buffer.getDouble(mutableAggBufferOffset) - val M2 = buffer.getDouble(mutableAggBufferOffset + 2) - val M4 = buffer.getDouble(mutableAggBufferOffset + 4) - - if (M0 == 0.0 || M2 == 0.0) { - Double.NaN - } else { - M0 * M4 / (M2 * M2) - 3.0 - } - case _ => 0.0 +abstract class SecondMoment(child: Expression) extends CentralMomentAgg(child) { + + protected val momentOrder = 2 + + protected def isBiased: Boolean + + protected def isStd: Boolean + + protected val sufficientMoments = Array(2) + + override def getStatistic(n: Double, moments: Array[Double]): Double = { + require(moments.length == sufficientMoments.length, + s"$prettyName requires one central moment, received: ${moments.length}") + + val m2 = moments.head + val divisor = if (isBiased) n else n - 1 + val variance = if (n == 0.0) { + Double.NaN + } else { + m2 / divisor + } + + if (isStd) { + math.sqrt(variance) + } else { + variance + } } } case class Variance(child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + inputAggBufferOffset: Int = 0) extends SecondMoment(child) { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -1173,23 +1163,13 @@ case class Variance(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "variance" - protected val maxMoment = 2 - -// override def eval(buffer: InternalRow): Any = { -// // stddev = M2 / (M0 - 1) -// val M0 = buffer.getDouble(mutableAggBufferOffset) -// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) -// -// if (M0 == 0.0) { -// 0.0 -// } else { -// M2 / (M0 - 1.0) -// } -// } + override protected val isBiased = false + + override protected val isStd = false } case class VarianceSamp(child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + inputAggBufferOffset: Int = 0) extends SecondMoment(child) { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -1199,23 +1179,13 @@ case class VarianceSamp(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "variance_samp" - protected val maxMoment = 2 - -// override def eval(buffer: InternalRow): Any = { -// // var_samp = M2 / (M0 - 1) -// val M0 = buffer.getDouble(mutableAggBufferOffset) -// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) -// -// if (M0 == 0.0) { -// 0.0 -// } else { -// M2 / (M0 - 1.0) -// } -// } + override protected val isBiased = false + + override protected val isStd = false } case class VariancePop(child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + inputAggBufferOffset: Int = 0) extends SecondMoment(child) { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -1225,19 +1195,9 @@ case class VariancePop(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "variance_pop" - protected val maxMoment = 2 - -// override def eval(buffer: InternalRow): Any = { -// // var_pop = M2 / M0 -// val M0 = buffer.getDouble(mutableAggBufferOffset) -// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) -// -// if (M0 == 0.0) { -// 0.0 -// } else { -// M2 / M0 -// } -// } + override protected val isBiased = true + + override protected val isStd = false } case class Skewness(child: Expression, mutableAggBufferOffset: Int = 0, @@ -1251,20 +1211,20 @@ case class Skewness(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "skewness" - protected val maxMoment = 3 - -// override def eval(buffer: InternalRow): Any = { -// // skewness = sqrt(M0) * M3 / sqrt(M2^3) -// val M0 = buffer.getDouble(mutableAggBufferOffset) -// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) -// val M3 = buffer.getDouble(mutableAggBufferOffset + 3) -// -// if (M0 == 0.0 || M2 == 0.0) { -// 0.0 -// } else { -// math.sqrt(M0) * M3 / math.sqrt(M2 * M2 * M2) -// } -// } + protected val momentOrder = 3 + + protected val sufficientMoments = Array(2, 3) + + override def getStatistic(n: Double, moments: Array[Double]): Double = { + require(moments.length == sufficientMoments.length, + s"skewness requires two central moments, received: ${moments.length}") + val Array(m2, m3) = moments + if (n == 0.0 || m2 == 0.0) { + Double.NaN + } else { + math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) + } + } } case class Kurtosis(child: Expression, mutableAggBufferOffset: Int = 0, @@ -1278,21 +1238,18 @@ case class Kurtosis(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "kurtosis" - protected val maxMoment = 4 - -// override def eval(buffer: InternalRow): Any = { -// // kurtosis = M0 * M4 / M2^2 - 3.0 -// // NOTE: this is the formula for excess kurtosis, which is default for R and NumPy -// val M0 = buffer.getDouble(mutableAggBufferOffset) -// val M2 = buffer.getDouble(mutableAggBufferOffset + 2) -// val M4 = buffer.getDouble(mutableAggBufferOffset + 4) -// -// if (M0 == 0.0) { -// 0.0 -// } else if (M2 == 0.0) { -// -3.0 -// } else { -// M0 * M4 / (M2 * M2) - 3.0 -// } -// } + protected val momentOrder = 4 + + protected val sufficientMoments = Array(2, 4) + + override def getStatistic(n: Double, moments: Array[Double]): Double = { + require(moments.length == sufficientMoments.length, + s"kurtosis requires two central moments, received: ${moments.length}") + val Array(m2, m4) = moments + if (n == 0.0 || m2 == 0.0) { + Double.NaN + } else { + n * m4 / (m2 * m2) - 3.0 + } + } } From dba511bf1633f9a3b9732179b59dc4741f6805a0 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 20 Oct 2015 21:22:31 -0700 Subject: [PATCH 12/22] style cleanup --- .../expressions/aggregate/functions.scala | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index f6f16e86974c5..f0768ba0a9b21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -951,10 +951,13 @@ object HyperLogLogPlusPlus { abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { /** - * The maximum central moment order to be computed + * The maximum central moment order to be computed. */ protected def momentOrder: Int + /** + * Array of sufficient moments need to compute the aggregate statistic. + */ protected def sufficientMoments: Array[Int] override def children: Seq[Expression] = Seq(child) @@ -974,7 +977,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) /** - * The number of central moments to store in the buffer + * The number of central moments to store in the buffer. */ private[this] val numMoments = 5 @@ -988,7 +991,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w aggBufferAttributes.map(_.newInstance()) /** - * Initialize all moments to zero + * Initialize all moments to zero. */ override def initialize(buffer: MutableRow): Unit = { var aggIndex = 0 @@ -1124,13 +1127,13 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w abstract class SecondMoment(child: Expression) extends CentralMomentAgg(child) { - protected val momentOrder = 2 + override protected val momentOrder = 2 protected def isBiased: Boolean protected def isStd: Boolean - protected val sufficientMoments = Array(2) + override protected val sufficientMoments = Array(2) override def getStatistic(n: Double, moments: Array[Double]): Double = { require(moments.length == sufficientMoments.length, @@ -1211,13 +1214,13 @@ case class Skewness(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "skewness" - protected val momentOrder = 3 + override protected val momentOrder = 3 - protected val sufficientMoments = Array(2, 3) + override protected val sufficientMoments = Array(2, 3) override def getStatistic(n: Double, moments: Array[Double]): Double = { require(moments.length == sufficientMoments.length, - s"skewness requires two central moments, received: ${moments.length}") + s"$prettyName requires two central moments, received: ${moments.length}") val Array(m2, m3) = moments if (n == 0.0 || m2 == 0.0) { Double.NaN @@ -1238,13 +1241,13 @@ case class Kurtosis(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "kurtosis" - protected val momentOrder = 4 + override protected val momentOrder = 4 - protected val sufficientMoments = Array(2, 4) + override protected val sufficientMoments = Array(2, 4) override def getStatistic(n: Double, moments: Array[Double]): Double = { require(moments.length == sufficientMoments.length, - s"kurtosis requires two central moments, received: ${moments.length}") + s"$prettyName requires two central moments, received: ${moments.length}") val Array(m2, m4) = moments if (n == 0.0 || m2 == 0.0) { Double.NaN From 44c1437ee53364db70b3f8eca8a21e1ea74f75a7 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 20 Oct 2015 22:26:28 -0700 Subject: [PATCH 13/22] reverting some stddev changes --- .../catalyst/expressions/aggregate/functions.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index f0768ba0a9b21..8d83fa44de3f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -378,7 +378,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { private val currentMk = AttributeReference("currentMk", resultType)() override val aggBufferAttributes = preCount :: currentCount :: preAvg :: - currentAvg :: currentMk :: Nil + currentAvg :: currentMk :: Nil override val initialValues = Seq( /* preCount = */ Cast(Literal(0), resultType), @@ -407,7 +407,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { Seq( /* preCount = */ If(IsNull(child), preCount, currentCount), /* currentCount = */ If(IsNull(child), currentCount, - Add(currentCount, Cast(Literal(1), resultType))), + Add(currentCount, Cast(Literal(1), resultType))), /* preAvg = */ If(IsNull(child), preAvg, currentAvg), /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), /* currentMk = */ If(IsNull(child), currentMk, mkAdd) @@ -424,7 +424,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { // average merge def avgMerge: Expression = { ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / - (preCount + currentCount.right) + (preCount + currentCount.right) } // update sum of square differences @@ -438,14 +438,14 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { Seq( /* preCount = */ If(IsNull(currentCount.left), - Cast(Literal(0), resultType), currentCount.left), + Cast(Literal(0), resultType), currentCount.left), /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, - If(IsNull(currentCount.right), currentCount.left, countMerge)), + If(IsNull(currentCount.right), currentCount.left, countMerge)), /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, - If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), + If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, - If(IsNull(currentMk.right), currentMk.left, mkMerge)) + If(IsNull(currentMk.right), currentMk.left, mkMerge)) ) } From 7baac9d647358b72e07f2c4f2044cb4d655afc15 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 21 Oct 2015 19:21:19 -0700 Subject: [PATCH 14/22] adding helper function for tests with tolerances --- .../org/apache/spark/sql/SQLQuerySuite.scala | 56 ++++++++++++------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cad679950c5d9..56c7ceadf2e80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -729,39 +729,55 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + private[this] def checkAnswerWithTol(dataFrame: DataFrame, + expectedAnswer: Array[Double], + absTol: Double = 0.0): Unit = { + + val sparkAnswer = dataFrame.first().toSeq + require(sparkAnswer.length == expectedAnswer.length, + s"spark and expected answer lengths should" + + s" be equal: ${sparkAnswer.length} != ${expectedAnswer.length}") + + sparkAnswer.zip(expectedAnswer).foreach { + case (spark: Double, expected: Double) => + assert(math.abs(spark - expected) < absTol, + s"actual answer $spark not within $absTol of correct answer $expected") + } + } + test("var_samp") { - checkAnswer( - sql("SELECT VAR_SAMP(a) FROM testData2"), - Row(4/5.0) - ) + val absTol = 1e-8 + val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2") + val expectedAnswer = Array(4.0 / 5.0) + checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) } test("variance") { - checkAnswer( - sql("SELECT VARIANCE(a) FROM testData2"), - Row(4/5.0) - ) + val absTol = 1e-8 + val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2") + val expectedAnswer = Array(4.0 / 5.0) + checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) } test("var_pop") { - checkAnswer( - sql("SELECT VAR_POP(a) FROM testData2"), - Row(4/6.0) - ) + val absTol = 1e-8 + val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2") + val expectedAnswer = Array(4.0 / 6.0) + checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) } test("skewness") { - checkAnswer( - sql("SELECT skewness(a) FROM testData2"), - Row(0.0) - ) + val absTol = 1e-8 + val sparkAnswer = sql("SELECT skewness(a) FROM testData2") + val expectedAnswer = Array(0.0) + checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) } test("kurtosis") { - checkAnswer( - sql("SELECT kurtosis(a) FROM testData2"), - Row(-1.5) - ) + val absTol = 1e-8 + val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2") + val expectedAnswer = Array(-1.5) + checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) } test("stddev agg") { From 345463ec367444acad658bc0eed7743a44085d43 Mon Sep 17 00:00:00 2001 From: sethah Date: Thu, 22 Oct 2015 11:55:37 -0700 Subject: [PATCH 15/22] more generic tests with tolerance function and placeholders for AggregateFunction1 --- .../sql/catalyst/expressions/aggregates.scala | 78 ++++++++----------- .../spark/sql/DataFrameAggregateSuite.scala | 57 +++++++++++--- .../org/apache/spark/sql/QueryTest.scala | 45 +++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 58 ++++++-------- 4 files changed, 148 insertions(+), 90 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 4a898d7599e6d..7896dcfb4bc82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -993,61 +993,27 @@ case class StddevFunction( } // placeholder -abstract class CentralMomentAgg1(child: Expression) - extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 4 digits after decimal point, like Hive - DecimalType.bounded(precision + 4, scale + 4) - case _ => - DoubleType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(precision, scale) => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() +case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression { - // partialSum already increase the precision by 10 - val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) - SplitEvaluation( - Cast(Divide(castedSum, castedCount), dataType), - partialCount :: partialSum :: Nil) + override def nullable: Boolean = false - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() + override def dataType: DoubleType.type = DoubleType -// val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedSum = Cast(Literal(0.0), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) - } - } + override def foldable: Boolean = false - override def newInstance(): AverageFunction = new AverageFunction(child, this) + override def prettyName: String = "kurtosis" - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") + override def toString: String = s"KURTOSIS($child)" } // placeholder -case class Kurtosis(child: Expression) extends CentralMomentAgg1(child) { +case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression { - override def prettyName: String = "kurtosis" + override def nullable: Boolean = false - override def toString: String = s"KURTOSIS($child)" -} + override def dataType: DoubleType.type = DoubleType -// placeholder -case class Skewness(child: Expression) extends CentralMomentAgg1(child) { + override def foldable: Boolean = false override def prettyName: String = "skewness" @@ -1055,7 +1021,13 @@ case class Skewness(child: Expression) extends CentralMomentAgg1(child) { } // placeholder -case class Variance(child: Expression) extends CentralMomentAgg1(child) { +case class Variance(child: Expression) extends UnaryExpression with AggregateExpression { + + override def nullable: Boolean = false + + override def dataType: DoubleType.type = DoubleType + + override def foldable: Boolean = false override def prettyName: String = "variance" @@ -1063,7 +1035,13 @@ case class Variance(child: Expression) extends CentralMomentAgg1(child) { } // placeholder -case class VariancePop(child: Expression) extends CentralMomentAgg1(child) { +case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression { + + override def nullable: Boolean = false + + override def dataType: DoubleType.type = DoubleType + + override def foldable: Boolean = false override def prettyName: String = "variance_pop" @@ -1071,7 +1049,13 @@ case class VariancePop(child: Expression) extends CentralMomentAgg1(child) { } // placeholder -case class VarianceSamp(child: Expression) extends CentralMomentAgg1(child) { +case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression { + + override def nullable: Boolean = false + + override def dataType: DoubleType.type = DoubleType + + override def foldable: Boolean = false override def prettyName: String = "variance_samp" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 0475d98fb8738..0da52e966df36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -223,38 +223,75 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("moments") { - checkAnswer( - testData2.agg(skewness('a)), - Row(0.0)) + val absTol = 1e-8 + + val sparkVariance = testData2.agg(variance('a)) + val expectedVariance = Row(4.0 / 5.0) + checkAggregatesWithTol(sparkVariance, expectedVariance, absTol) + val sparkVarianceSamp = testData2.agg(var_samp('a)) + checkAggregatesWithTol(sparkVarianceSamp, expectedVariance, absTol) + + val sparkVariancePop = testData2.agg(var_pop('a)) + val expectedVariancePop = Row(4.0 / 6.0) + checkAggregatesWithTol(sparkVariancePop, expectedVariancePop, absTol) + + val sparkSkewness= testData2.agg(skewness('a)) + val expectedSkewness = Row(0.0) + checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol) + + val sparkKurtosis = testData2.agg(kurtosis('a)) + val expectedKurtosis = Row(-1.5) + checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol) - checkAnswer( - testData2.agg(kurtosis('a)), - Row(-1.5)) } test("zero moments") { val emptyTableData = Seq((1, 2)).toDF("a", "b") assert(emptyTableData.count() === 1) + checkAnswer( + emptyTableData.agg(variance('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(var_samp('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(var_pop('a)), + Row(Double.NaN)) + checkAnswer( emptyTableData.agg(skewness('a)), - Row(0.0)) + Row(Double.NaN)) checkAnswer( emptyTableData.agg(kurtosis('a)), - Row(-3.0)) + Row(Double.NaN)) } test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") assert(emptyTableData.count() === 0) + checkAnswer( + emptyTableData.agg(variance('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(var_samp('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(var_pop('a)), + Row(Double.NaN)) + checkAnswer( emptyTableData.agg(skewness('a)), - Row(null)) + Row(Double.NaN)) checkAnswer( emptyTableData.agg(kurtosis('a)), - Row(null)) + Row(Double.NaN)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index e3c5a426671d0..7ca38c7666dff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -88,6 +88,29 @@ abstract class QueryTest extends PlanTest { checkAnswer(df, expectedAnswer.collect()) } + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * @param dataFrame the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param absTol the absolute tolerance between actual and expected answers. + */ + protected def checkAggregatesWithTol(dataFrame: DataFrame, expectedAnswer: Seq[Row], absTol: Double): Unit = { + // TODO: catch exceptions in data frame execution + val actualAnswer = dataFrame.collect() + require(actualAnswer.length == expectedAnswer.length, + s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}") + + actualAnswer.zip(expectedAnswer).foreach { + case (actualRow, expectedRow) => QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol) + } + } + + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Row, + absTol: Double): Unit = { + checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol) + } + /** * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. */ @@ -168,6 +191,28 @@ object QueryTest { return None } + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * @param actualAnswer the actual result in a [[Row]]. + * @param expectedAnswer the expected result in a[[Row]]. + * @param absTol the absolute tolerance between actual and expected answers. + */ + protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = { + require(actualAnswer.length == expectedAnswer.length, + s"actual answer length ${actualAnswer.length} != " + + s"expected answer length ${expectedAnswer.length}") + + // TODO: support other numeric types besides Double + // TODO: support struct types? + actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach { + case (actual: Double, expected: Double) => + assert(math.abs(actual - expected) < absTol, + s"actual answer $actual not within $absTol of correct answer $expected") + case (actual, expected) => + assert(actual == expected, s"$actual did not equal $expected") + } + } + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { checkAnswer(df, expectedAnswer.asScala) match { case Some(errorMessage) => errorMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 56c7ceadf2e80..7ae563f5b9de4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -517,8 +517,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + - "AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(0, -1.5, 1, 3, 2, 1, 6, 3) + "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(0, -1.5, 1, 3, 2, 1, 1, 6, 3) ) } @@ -711,14 +711,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("stddev") { checkAnswer( sql("SELECT STDDEV(a) FROM testData2"), - Row(math.sqrt(4/5.0)) + Row(math.sqrt(4.0 / 5.0)) ) } test("stddev_pop") { checkAnswer( sql("SELECT STDDEV_POP(a) FROM testData2"), - Row(math.sqrt(4/6.0)) + Row(math.sqrt(4.0 / 6.0)) ) } @@ -729,55 +729,39 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } - private[this] def checkAnswerWithTol(dataFrame: DataFrame, - expectedAnswer: Array[Double], - absTol: Double = 0.0): Unit = { - - val sparkAnswer = dataFrame.first().toSeq - require(sparkAnswer.length == expectedAnswer.length, - s"spark and expected answer lengths should" + - s" be equal: ${sparkAnswer.length} != ${expectedAnswer.length}") - - sparkAnswer.zip(expectedAnswer).foreach { - case (spark: Double, expected: Double) => - assert(math.abs(spark - expected) < absTol, - s"actual answer $spark not within $absTol of correct answer $expected") - } - } - test("var_samp") { val absTol = 1e-8 val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2") - val expectedAnswer = Array(4.0 / 5.0) - checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) + val expectedAnswer = Row(4.0 / 5.0) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } test("variance") { val absTol = 1e-8 val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2") - val expectedAnswer = Array(4.0 / 5.0) - checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) + val expectedAnswer = Row(4.0 / 5.0) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } test("var_pop") { val absTol = 1e-8 val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2") - val expectedAnswer = Array(4.0 / 6.0) - checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) + val expectedAnswer = Row(4.0 / 6.0) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } test("skewness") { val absTol = 1e-8 val sparkAnswer = sql("SELECT skewness(a) FROM testData2") - val expectedAnswer = Array(0.0) - checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) + val expectedAnswer = Row(0.0) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } test("kurtosis") { val absTol = 1e-8 val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2") - val expectedAnswer = Array(-1.5) - checkAnswerWithTol(sparkAnswer, expectedAnswer, absTol) + val expectedAnswer = Row(-1.5) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } test("stddev agg") { @@ -786,10 +770,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { (1 to 3).map(i => Row(i, math.sqrt(1 / 2.0), math.sqrt(1 / 4.0), math.sqrt(1 / 2.0)))) } + test("variance agg") { + val absTol = 1e-8 + val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a") + val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) + } + test("skewness and kurtosis agg") { - checkAnswer( - sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, 0.0, -2.0))) + val absTol = 1e-8 + val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a") + val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0)) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } test("inner join where, one match per row") { From 3ef2faaa83104273ab185a3190f2d133a04693bb Mon Sep 17 00:00:00 2001 From: sethah Date: Thu, 22 Oct 2015 12:15:26 -0700 Subject: [PATCH 16/22] style and readability updates --- .../sql/catalyst/expressions/aggregate/functions.scala | 4 +--- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 2 +- .../src/test/scala/org/apache/spark/sql/QueryTest.scala | 7 +++++-- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 5 +++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 8d83fa44de3f1..e07e5d383b0d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -994,10 +994,8 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * Initialize all moments to zero. */ override def initialize(buffer: MutableRow): Unit = { - var aggIndex = 0 - while (aggIndex < numMoments) { + for (aggIndex <- 0 until numMoments) { buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) - aggIndex += 1 } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 0da52e966df36..fe1692c7c45a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -235,7 +235,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val expectedVariancePop = Row(4.0 / 6.0) checkAggregatesWithTol(sparkVariancePop, expectedVariancePop, absTol) - val sparkSkewness= testData2.agg(skewness('a)) + val sparkSkewness = testData2.agg(skewness('a)) val expectedSkewness = Row(0.0) checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 7ca38c7666dff..ec2dada8e46db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -94,14 +94,17 @@ abstract class QueryTest extends PlanTest { * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. * @param absTol the absolute tolerance between actual and expected answers. */ - protected def checkAggregatesWithTol(dataFrame: DataFrame, expectedAnswer: Seq[Row], absTol: Double): Unit = { + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Seq[Row], + absTol: Double): Unit = { // TODO: catch exceptions in data frame execution val actualAnswer = dataFrame.collect() require(actualAnswer.length == expectedAnswer.length, s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}") actualAnswer.zip(expectedAnswer).foreach { - case (actualRow, expectedRow) => QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol) + case (actualRow, expectedRow) => + QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7ae563f5b9de4..24b30b74e8ba0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -767,12 +767,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("stddev agg") { checkAnswer( sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, math.sqrt(1 / 2.0), math.sqrt(1 / 4.0), math.sqrt(1 / 2.0)))) + (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0)))) } test("variance agg") { val absTol = 1e-8 - val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a") + val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" + + "FROM testData2 GROUP BY a") val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)) checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } From fd3f4d6f9ba5124406a7078c9e7991bf91abdad6 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 23 Oct 2015 10:10:09 -0700 Subject: [PATCH 17/22] addressing feedback --- .../expressions/aggregate/functions.scala | 223 +++++++++--------- 1 file changed, 113 insertions(+), 110 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index e07e5d383b0d3..613bcea4a19d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -939,6 +939,9 @@ object HyperLogLogPlusPlus { * This class implements online, one-pass algorithms for computing the central moments of a set of * points. * + * Returns `Double.NaN` when N = 0 or N = 1 + * -third and fourth moments return `Double.NaN` when second moment is zero + * * References: * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." * 2015. http://arxiv.org/abs/1510.04923 @@ -951,15 +954,10 @@ object HyperLogLogPlusPlus { abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { /** - * The maximum central moment order to be computed. + * The central moment order to be computed. */ protected def momentOrder: Int - /** - * Array of sufficient moments need to compute the aggregate statistic. - */ - protected def sufficientMoments: Array[Int] - override def children: Seq[Expression] = Seq(child) override def nullable: Boolean = false @@ -977,11 +975,11 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) /** - * The number of central moments to store in the buffer. + * Size of aggregation buffer. */ - private[this] val numMoments = 5 + private[this] val bufferSize = 5 - override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(numMoments) { i => + override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => AttributeReference(s"M$i", DoubleType)() } @@ -990,21 +988,33 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) + // buffer offsets + private[this] val nOffset = mutableAggBufferOffset + private[this] val meanOffset = mutableAggBufferOffset + 1 + private[this] val secondMomentOffset = mutableAggBufferOffset + 2 + private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 + private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 + + // frequently used values for online updates + private[this] var delta = 0.0 + private[this] var deltaN = 0.0 + private[this] var delta2 = 0.0 + private[this] var deltaN2 = 0.0 + private[this] var n = 0.0 + private[this] var mean = 0.0 + private[this] var m2 = 0.0 + private[this] var m3 = 0.0 + private[this] var m4 = 0.0 + /** * Initialize all moments to zero. */ override def initialize(buffer: MutableRow): Unit = { - for (aggIndex <- 0 until numMoments) { + for (aggIndex <- 0 until bufferSize) { buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) } } - // frequently used values for online updates - private[this] var delta = 0.0 - private[this] var deltaN = 0.0 - private[this] var delta2 = 0.0 - private[this] var deltaN2 = 0.0 - /** * Update the central moments buffer. */ @@ -1013,40 +1023,36 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w if (v != null) { val updateValue = v match { case d: Double => d - case _ => 0.0 } - var n = buffer.getDouble(mutableAggBufferOffset) - var mean = buffer.getDouble(mutableAggBufferOffset + 1) - var m2 = 0.0 - var m3 = 0.0 - var m4 = 0.0 + n = buffer.getDouble(nOffset) + mean = buffer.getDouble(meanOffset) n += 1.0 + buffer.setDouble(nOffset, n) delta = updateValue - mean deltaN = delta / n mean += deltaN - buffer.setDouble(mutableAggBufferOffset, n) - buffer.setDouble(mutableAggBufferOffset + 1, mean) + buffer.setDouble(meanOffset, mean) if (momentOrder >= 2) { - m2 = buffer.getDouble(mutableAggBufferOffset + 2) + m2 = buffer.getDouble(secondMomentOffset) m2 += delta * (delta - deltaN) - buffer.setDouble(mutableAggBufferOffset + 2, m2) + buffer.setDouble(secondMomentOffset, m2) } if (momentOrder >= 3) { delta2 = delta * delta deltaN2 = deltaN * deltaN - m3 = buffer.getDouble(mutableAggBufferOffset + 3) + m3 = buffer.getDouble(thirdMomentOffset) m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) - buffer.setDouble(mutableAggBufferOffset + 3, m3) + buffer.setDouble(thirdMomentOffset, m3) } if (momentOrder >= 4) { - m4 = buffer.getDouble(mutableAggBufferOffset + 4) + m4 = buffer.getDouble(fourthMomentOffset) m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + delta * (delta * delta2 - deltaN * deltaN2) - buffer.setDouble(mutableAggBufferOffset + 4, m4) + buffer.setDouble(fourthMomentOffset, m4) } } } @@ -1055,106 +1061,85 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * Merge two central moment buffers. */ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val n1 = buffer1.getDouble(mutableAggBufferOffset) + val n1 = buffer1.getDouble(nOffset) val n2 = buffer2.getDouble(inputAggBufferOffset) - val mean1 = buffer1.getDouble(mutableAggBufferOffset + 1) + val mean1 = buffer1.getDouble(meanOffset) val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) var secondMoment1 = 0.0 var secondMoment2 = 0.0 - var secondMoment = 0.0 var thirdMoment1 = 0.0 var thirdMoment2 = 0.0 - var thirdMoment = 0.0 var fourthMoment1 = 0.0 var fourthMoment2 = 0.0 - var fourthMoment = 0.0 - val n = n1 + n2 + n = n1 + n2 + buffer1.setDouble(nOffset, n) delta = mean2 - mean1 deltaN = delta / n - val mean = mean1 + deltaN * n2 - - buffer1.setDouble(mutableAggBufferOffset, n) + mean = mean1 + deltaN * n buffer1.setDouble(mutableAggBufferOffset + 1, mean) // higher order moments computed according to: // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics if (momentOrder >= 2) { - secondMoment1 = buffer1.getDouble(mutableAggBufferOffset + 2) + secondMoment1 = buffer1.getDouble(secondMomentOffset) secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) - secondMoment = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 - buffer1.setDouble(mutableAggBufferOffset + 2, secondMoment) + m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 + buffer1.setDouble(secondMomentOffset, m2) } - if (momentOrder >= 3) { - thirdMoment1 = buffer1.getDouble(mutableAggBufferOffset + 3) + thirdMoment1 = buffer1.getDouble(thirdMomentOffset) thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) - thirdMoment = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * + m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) - buffer1.setDouble(mutableAggBufferOffset + 3, thirdMoment) + buffer1.setDouble(thirdMomentOffset, m3) } if (momentOrder >= 4) { - fourthMoment1 = buffer1.getDouble(mutableAggBufferOffset + 4) + fourthMoment1 = buffer1.getDouble(fourthMomentOffset) fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) - fourthMoment = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * + m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) - buffer1.setDouble(mutableAggBufferOffset + 4, fourthMoment) + buffer1.setDouble(fourthMomentOffset, m4) } } /** * Compute aggregate statistic from sufficient moments. + * @param centralMoments Length `momentOrder + 1` array of central moments needed to + * compute the aggregate stat. */ - def getStatistic(n: Double, moments: Array[Double]): Double + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double override final def eval(buffer: InternalRow): Any = { - val n = buffer.getDouble(mutableAggBufferOffset) - val moments = sufficientMoments.map { momentIdx => - buffer.getDouble(mutableAggBufferOffset + momentIdx) + val n = buffer.getDouble(nOffset) + val mean = buffer.getDouble(meanOffset) + val moments = Array.ofDim[Double](momentOrder + 1) + moments(0) = n + moments(1) = mean + if (momentOrder >= 2) { + moments(2) = buffer.getDouble(secondMomentOffset) } - getStatistic(n, moments) - } -} - -abstract class SecondMoment(child: Expression) extends CentralMomentAgg(child) { - - override protected val momentOrder = 2 - - protected def isBiased: Boolean - - protected def isStd: Boolean - - override protected val sufficientMoments = Array(2) - - override def getStatistic(n: Double, moments: Array[Double]): Double = { - require(moments.length == sufficientMoments.length, - s"$prettyName requires one central moment, received: ${moments.length}") - - val m2 = moments.head - val divisor = if (isBiased) n else n - 1 - val variance = if (n == 0.0) { - Double.NaN - } else { - m2 / divisor + if (momentOrder >= 3) { + moments(3) = buffer.getDouble(thirdMomentOffset) } - - if (isStd) { - math.sqrt(variance) - } else { - variance + if (momentOrder >= 4) { + moments(4) = buffer.getDouble(fourthMomentOffset) } + + getStatistic(n, mean, moments) } } -case class Variance(child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends SecondMoment(child) { +case class Variance(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -1164,13 +1149,19 @@ case class Variance(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "variance" - override protected val isBiased = false + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - override protected val isStd = false + if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + } } -case class VarianceSamp(child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends SecondMoment(child) { +case class VarianceSamp(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -1180,13 +1171,19 @@ case class VarianceSamp(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "variance_samp" - override protected val isBiased = false + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - override protected val isStd = false + if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + } } -case class VariancePop(child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends SecondMoment(child) { +case class VariancePop(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -1196,13 +1193,19 @@ case class VariancePop(child: Expression, mutableAggBufferOffset: Int = 0, override def prettyName: String = "variance_pop" - override protected val isBiased = true + override protected val momentOrder = 2 - override protected val isStd = false + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") + + if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / n + } } -case class Skewness(child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { +case class Skewness(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -1214,12 +1217,11 @@ case class Skewness(child: Expression, mutableAggBufferOffset: Int = 0, override protected val momentOrder = 3 - override protected val sufficientMoments = Array(2, 3) - - override def getStatistic(n: Double, moments: Array[Double]): Double = { - require(moments.length == sufficientMoments.length, - s"$prettyName requires two central moments, received: ${moments.length}") - val Array(m2, m3) = moments + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + val m2 = moments(2) + val m3 = moments(3) if (n == 0.0 || m2 == 0.0) { Double.NaN } else { @@ -1228,8 +1230,9 @@ case class Skewness(child: Expression, mutableAggBufferOffset: Int = 0, } } -case class Kurtosis(child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { +case class Kurtosis(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -1241,12 +1244,12 @@ case class Kurtosis(child: Expression, mutableAggBufferOffset: Int = 0, override protected val momentOrder = 4 - override protected val sufficientMoments = Array(2, 4) - - override def getStatistic(n: Double, moments: Array[Double]): Double = { - require(moments.length == sufficientMoments.length, - s"$prettyName requires two central moments, received: ${moments.length}") - val Array(m2, m4) = moments + // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + val m2 = moments(2) + val m4 = moments(4) if (n == 0.0 || m2 == 0.0) { Double.NaN } else { From cf8a14bb24924079af2c30234a083a1e6c6d4c23 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 23 Oct 2015 16:03:23 -0700 Subject: [PATCH 18/22] correcting error in merge function --- .../expressions/aggregate/functions.scala | 20 ++++++++++--------- .../spark/sql/DataFrameAggregateSuite.scala | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 613bcea4a19d7..857cbc157cf54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -939,8 +939,9 @@ object HyperLogLogPlusPlus { * This class implements online, one-pass algorithms for computing the central moments of a set of * points. * - * Returns `Double.NaN` when N = 0 or N = 1 - * -third and fourth moments return `Double.NaN` when second moment is zero + * Behavior: + * - null values are ignored + * - returns `Double.NaN` when the column contains `Double.NaN` values * * References: * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." @@ -1024,6 +1025,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w val updateValue = v match { case d: Double => d } + n = buffer.getDouble(nOffset) mean = buffer.getDouble(meanOffset) @@ -1078,8 +1080,8 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w n = n1 + n2 buffer1.setDouble(nOffset, n) delta = mean2 - mean1 - deltaN = delta / n - mean = mean1 + deltaN * n + deltaN = if (n == 0.0) 0.0 else delta / n + mean = mean1 + deltaN * n2 buffer1.setDouble(mutableAggBufferOffset + 1, mean) // higher order moments computed according to: @@ -1112,8 +1114,8 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w /** * Compute aggregate statistic from sufficient moments. - * @param centralMoments Length `momentOrder + 1` array of central moments needed to - * compute the aggregate stat. + * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) + * needed to compute the aggregate stat. */ def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double @@ -1121,8 +1123,8 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w val n = buffer.getDouble(nOffset) val mean = buffer.getDouble(meanOffset) val moments = Array.ofDim[Double](momentOrder + 1) - moments(0) = n - moments(1) = mean + moments(0) = 1.0 + moments(1) = 0.0 if (momentOrder >= 2) { moments(2) = buffer.getDouble(secondMomentOffset) } @@ -1199,7 +1201,7 @@ case class VariancePop(child: Expression, require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / n + if (n == 0.0) Double.NaN else moments(2) / n } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index fe1692c7c45a7..cb045e7fd241b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -259,7 +259,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(var_pop('a)), - Row(Double.NaN)) + Row(0.0)) checkAnswer( emptyTableData.agg(skewness('a)), From b86386a78ab4b27ca9b6b84afe6db4973204e9e9 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 23 Oct 2015 16:43:23 -0700 Subject: [PATCH 19/22] adding back some stddev codegen tests --- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 24b30b74e8ba0..c5274183a2dde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -329,6 +329,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) + // STDDEV + testCodeGen( + "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", + (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) + testCodeGen( + "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", + Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. testCodeGen( """ @@ -349,8 +356,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", + Row(null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) From 3045e3b1d82ac73d154dc4d2165e920c74bdc118 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 26 Oct 2015 10:06:06 -0700 Subject: [PATCH 20/22] changing variance to default to population variance --- .../catalyst/expressions/aggregate/functions.scala | 2 +- .../scala/org/apache/spark/sql/functions.scala | 4 ++-- .../apache/spark/sql/DataFrameAggregateSuite.scala | 14 +++++++------- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 857cbc157cf54..281404f285a98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -1157,7 +1157,7 @@ case class Variance(child: Expression, require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + if (n == 0.0) Double.NaN else moments(2) / n } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 08ad17b9b5b67..c1737b1ef663c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -413,7 +413,7 @@ object functions { def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) /** - * Aggregate function: returns the unbiased variance of the values in a group. + * Aggregate function: returns the population variance of the values in a group. * * @group agg_funcs * @since 1.6.0 @@ -421,7 +421,7 @@ object functions { def variance(e: Column): Column = Variance(e.expr) /** - * Aggregate function: returns the unbiased variance of the values in a group. + * Aggregate function: returns the population variance of the values in a group. * * @group agg_funcs * @since 1.6.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index cb045e7fd241b..9b23977c765dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -226,14 +226,14 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) - val expectedVariance = Row(4.0 / 5.0) + val expectedVariance = Row(4.0 / 6.0) checkAggregatesWithTol(sparkVariance, expectedVariance, absTol) - val sparkVarianceSamp = testData2.agg(var_samp('a)) - checkAggregatesWithTol(sparkVarianceSamp, expectedVariance, absTol) - val sparkVariancePop = testData2.agg(var_pop('a)) - val expectedVariancePop = Row(4.0 / 6.0) - checkAggregatesWithTol(sparkVariancePop, expectedVariancePop, absTol) + checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol) + + val sparkVarianceSamp = testData2.agg(var_samp('a)) + val expectedVarianceSamp = Row(4.0 / 5.0) + checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol) val sparkSkewness = testData2.agg(skewness('a)) val expectedSkewness = Row(0.0) @@ -251,7 +251,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a)), - Row(Double.NaN)) + Row(0.0)) checkAnswer( emptyTableData.agg(var_samp('a)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c5274183a2dde..af653158e3474 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -525,7 +525,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(0, -1.5, 1, 3, 2, 1, 1, 6, 3) + Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3) ) } @@ -746,7 +746,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("variance") { val absTol = 1e-8 val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2") - val expectedAnswer = Row(4.0 / 5.0) + val expectedAnswer = Row(4.0 / 6.0) checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } @@ -781,7 +781,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val absTol = 1e-8 val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" + "FROM testData2 GROUP BY a") - val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)) + val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0)) checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } From ff363cca57e2b1c2bb28e281d014d33b930fd603 Mon Sep 17 00:00:00 2001 From: sethah Date: Thu, 29 Oct 2015 14:06:36 +0100 Subject: [PATCH 21/22] removing fetch_aggregation from whitelist --- .../apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index eed9e436f9af7..9e357bf348c94 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -467,7 +467,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "escape_orderby1", "escape_sortby1", "explain_rearrange", - "fetch_aggregation", "fileformat_mix", "fileformat_sequencefile", "fileformat_text", From f49ce5c0d6594f3b5348eda75e9aefed10617066 Mon Sep 17 00:00:00 2001 From: sethah Date: Thu, 29 Oct 2015 17:53:59 +0100 Subject: [PATCH 22/22] Throw UnsupportedOperationException for AggregateExpression1 --- .../sql/catalyst/expressions/aggregates.scala | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 7896dcfb4bc82..411d2f56eba16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -993,7 +993,12 @@ case class StddevFunction( } // placeholder -case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression { +case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } override def nullable: Boolean = false @@ -1007,7 +1012,12 @@ case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExp } // placeholder -case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression { +case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } override def nullable: Boolean = false @@ -1021,7 +1031,12 @@ case class Skewness(child: Expression) extends UnaryExpression with AggregateExp } // placeholder -case class Variance(child: Expression) extends UnaryExpression with AggregateExpression { +case class Variance(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } override def nullable: Boolean = false @@ -1035,7 +1050,12 @@ case class Variance(child: Expression) extends UnaryExpression with AggregateExp } // placeholder -case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression { +case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } override def nullable: Boolean = false @@ -1049,7 +1069,12 @@ case class VariancePop(child: Expression) extends UnaryExpression with Aggregate } // placeholder -case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression { +case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } override def nullable: Boolean = false