From 4a6f903897d28a3038918997e692410259a90ae3 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 19 Jun 2020 10:36:52 +0800 Subject: [PATCH 01/15] Reuse completeNextStageWithFetchFailure --- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9d412f2dba3ce..762b14e170fcc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1796,9 +1796,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // lets say there is a fetch failure in this task set, which makes us go back and // run stage 0, attempt 1 - complete(taskSets(1), Seq( - (FetchFailed(makeBlockManagerId("hostA"), - shuffleDep1.shuffleId, 0L, 0, 0, "ignored"), null))) + completeNextStageWithFetchFailure(1, 0, shuffleDep1) scheduler.resubmitFailedStages() // stage 0, attempt 1 should have the properties of job2 @@ -1872,9 +1870,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have the second stage complete normally completeShuffleMapStageSuccessfully(1, 0, 1, Seq("hostA", "hostC")) // fail the third stage because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), - shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null))) + completeNextStageWithFetchFailure(2, 0, shuffleDepTwo) // TODO assert this: // blockManagerMaster.removeExecutor("hostA-exec") // have DAGScheduler try again @@ -1900,9 +1896,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // complete stage 1 completeShuffleMapStageSuccessfully(1, 0, 1) // pretend stage 2 failed because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), - shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null))) + completeNextStageWithFetchFailure(2, 0, shuffleDepTwo) // TODO assert this: // blockManagerMaster.removeExecutor("hostA-exec") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. From 199aa6f637515455fd2cdd026ea5d189f6291a4f Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Tue, 28 Jul 2020 18:04:56 +0800 Subject: [PATCH 02/15] Support single distinct group with filter. --- .../sql/catalyst/analysis/Analyzer.scala | 12 +-- .../optimizer/RewriteDistinctAggregates.scala | 93 ++++++++++++++++--- .../analysis/AnalysisErrorSuite.scala | 5 - .../sql/execution/aggregate/AggUtils.scala | 15 ++- 4 files changed, 93 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 13d98d8ce9b00..1c7aa4b532fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1973,15 +1973,9 @@ class Analyzer( } // We get an aggregate function, we need to wrap it in an AggregateExpression. 
case agg: AggregateFunction =>
-          // TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT
-          if (filter.isDefined) {
-            if (isDistinct) {
-              failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " +
-                "at the same time")
-            } else if (!filter.get.deterministic) {
-              failAnalysis("FILTER expression is non-deterministic, " +
-                "it cannot be used in aggregate functions")
-            }
+          if (filter.isDefined && !filter.get.deterministic) {
+            failAnalysis("FILTER expression is non-deterministic, " +
+              "it cannot be used in aggregate functions")
           }
           AggregateExpression(agg, Complete, isDistinct, filter)
         // This function is not an aggregate function, just return the resolved one.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 15aa02ff677de..a70b79819dfea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.IntegerType
@@ -144,28 +144,24 @@ import org.apache.spark.sql.types.IntegerType
  */
 object RewriteDistinctAggregates extends Rule[LogicalPlan] {

-  private def mayNeedtoRewrite(exprs: Seq[Expression]): Boolean = {
-    val distinctAggs = exprs.flatMap { _.collect {
-      case ae: AggregateExpression if ae.isDistinct => ae
-    }}
-    // We need at least two distinct aggregates for this rule because aggregation
-    // strategy can handle a single distinct group.
+  private def mayNeedtoRewrite(a: Aggregate): Boolean = {
+    val aggExpressions = collectAggregateExprs(a)
+    val distinctAggs = aggExpressions.filter(_.isDistinct)
+    // We need at least two distinct aggregates, or a single distinct aggregate with a filter
+    // clause, for this rule, because the aggregation strategy can handle a single distinct
+    // group without a filter.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
-    distinctAggs.size > 1
+    distinctAggs.size > 1 || (distinctAggs.size == 1 && aggExpressions.exists(_.filter.isDefined))
   }

   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a)
+    case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
   }

   def rewrite(a: Aggregate): Aggregate = {
     // Collect all aggregate expressions.
-    val aggExpressions = a.aggregateExpressions.flatMap { e =>
-      e.collect {
-        case ae: AggregateExpression => ae
-      }
-    }
+    val aggExpressions = collectAggregateExprs(a)

     // Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => @@ -326,11 +322,80 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { }.asInstanceOf[NamedExpression] } Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) + } else if (mayNeedtoProject(a)) { + var currentExprId = 0 + val (projections, filterProjections, aggPairs) = aggExpressions.map { + case ae @ AggregateExpression(af, _, isDistinct, filter, _) => + val unfoldableChildren = af.children.filter(!_.foldable) + val projectionMap = unfoldableChildren.map { + case e: Expression => + currentExprId += 1 + e -> Alias(e, s"_gen_attr_$currentExprId")() + } + val projection = projectionMap.map(_._2) + val filterProjection = filter.map { + case e: Expression => + currentExprId += 1 + e -> Alias(e, s"_gen_attr_$currentExprId")() + } + val exprAttrs = projectionMap.map { kv => + (kv._1, kv._2.toAttribute) + } + val exprAttrLookup = exprAttrs.toMap + val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c)) + val naf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val aggExpr = if (isDistinct && filterProjection.isDefined) { + val newFilter = Some(EqualNullSafe(filterProjection.get._2.toAttribute, Literal(true))) + ae.copy(aggregateFunction = naf, filter = newFilter) + } else { + ae.copy(aggregateFunction = naf, filter = None) + } + (projection, filterProjection.map(_._2), (ae, aggExpr)) + }.unzip3 + val namedGroupingProjection = a.groupingExpressions.flatMap { e => + e.collect { + case ar: AttributeReference => ar + } + } + val rewriteAggProjection = (namedGroupingProjection ++ projections.flatten).distinct ++ + filterProjections.flatten + val project = Project(rewriteAggProjection, a.child) + val rewriteAggExprLookup = aggPairs.toMap + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) + }.asInstanceOf[NamedExpression] + } + val aggregate = Aggregate(a.groupingExpressions, patchedAggExpressions, project) + if (aggExpressions.filter(_.isDistinct).size > 1) { + rewrite(aggregate) + } else { + aggregate + } } else { a } } + private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = { + a.aggregateExpressions.flatMap { _.collect { + case ae: AggregateExpression => ae + }} + } + + private def mayNeedtoProject(a: Aggregate): Boolean = { + var flag = true + a resolveOperatorsUp { + case p: Project => + if (p.output.exists(_.name.startsWith("_gen_attr_"))) { + flag = false + } + p + case other => other + } + flag + } + private def nullify(e: Expression) = Literal.create(null, e.dataType) private def expressionAttributePair(e: Expression) = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 166ffec44a60d..a99f7e2be6e7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -207,11 +207,6 @@ class AnalysisErrorSuite extends AnalysisTest { "FILTER (WHERE c > 1)"), "FILTER predicate specified, but aggregate is not an aggregate function" :: Nil) - errorTest( - "DISTINCT aggregate function with filter predicate", - CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"), - "DISTINCT and FILTER cannot be used in aggregate functions at 
the same time" :: Nil) - errorTest( "non-deterministic filter predicate in aggregate functions", CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 761ac20e84744..7ad01c3d31fb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -142,6 +142,7 @@ object AggUtils { val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) + val filterWithDistinctAttributes = functionsWithDistinct.flatMap(_.filterAttributes.toSeq) // 1. Create an Aggregate Operator for partial aggregations. val partialAggregate: SparkPlan = { @@ -151,10 +152,12 @@ object AggUtils { // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. createAggregate( - groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions, + groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions ++ + filterWithDistinctAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ + filterWithDistinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = child) } @@ -166,11 +169,13 @@ object AggUtils { createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), - groupingExpressions = groupingAttributes ++ distinctAttributes, + groupingExpressions = groupingAttributes ++ distinctAttributes ++ + filterWithDistinctAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, resultExpressions = groupingAttributes ++ distinctAttributes ++ + filterWithDistinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = partialAggregate) } @@ -201,7 +206,8 @@ object AggUtils { // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. - val expr = AggregateExpression(func, Partial, isDistinct = true) + val filter = functionsWithDistinct(i).filter + val expr = AggregateExpression(func, Partial, isDistinct = true, filter) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute val attr = functionsWithDistinct(i).resultAttribute @@ -233,7 +239,8 @@ object AggUtils { // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. 
- val expr = AggregateExpression(func, Final, isDistinct = true) + val filter = functionsWithDistinct(i).filter + val expr = AggregateExpression(func, Final, isDistinct = true, filter) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute val attr = functionsWithDistinct(i).resultAttribute From a73f11ef1cebd7d962188727f15ee3ca8582aeec Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Wed, 29 Jul 2020 18:05:49 +0800 Subject: [PATCH 03/15] Support distinct agg with filter --- .../optimizer/RewriteDistinctAggregates.scala | 114 +++++++----------- .../sql/execution/aggregate/AggUtils.scala | 15 +-- .../sql-tests/inputs/group-by-filter.sql | 41 ++++--- .../inputs/postgreSQL/aggregates_part3.sql | 7 +- .../inputs/postgreSQL/groupingsets.sql | 5 +- 5 files changed, 78 insertions(+), 104 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index a70b79819dfea..0e258784e5969 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.IntegerType @@ -181,7 +181,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group. - if (distinctAggGroups.size > 1) { + if (distinctAggGroups.size >= 1) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -191,7 +191,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val groupByAttrs = groupByMap.map(_._2) // Functions used to modify aggregate functions and their inputs. 
- def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) + def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) = + if (condition.isDefined) { + If(And(EqualTo(gid, id), condition.get), e, nullify(e)) + } else { + If(EqualTo(gid, id), e, nullify(e)) + } + def patchAggregateFunctionChildren( af: AggregateFunction)( attrs: Expression => Option[Expression]): AggregateFunction = { @@ -203,13 +209,33 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) + val distinctAggExprs = aggExpressions + .filter(e => e.isDistinct && e.children.exists(!_.foldable)) + val distinctAggFilterAttrMap = distinctAggExprs.collect { + case AggregateExpression(_, _, _, filter, _) if filter.isDefined => + val (e, attr) = expressionAttributePair(filter.get) + val aggregateExp = AggregateExpression(Max(attr), Partial, false) + (e, attr, Alias(aggregateExp, attr.name)()) + } + val distinctAggFilters = distinctAggFilterAttrMap.map(_._1) + val distinctAggFilterAttrs = distinctAggFilterAttrMap.map(_._2) + val boolOrs = distinctAggFilterAttrMap.map(_._3) // Setup expand & aggregate operators for distinct aggregate expressions. val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap + val distinctAggFilterAttrLookup = distinctAggFilterAttrMap.map { tuple3 => + tuple3._1 -> tuple3._3.toAttribute + }.toMap val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { case ((group, expressions), i) => val id = Literal(i + 1) + val filters = expressions.filter(_.filter.isDefined).map(_.filter.get) + val filterProjection = distinctAggFilters.map { + case e if filters.contains(e) => e + case e => nullify(e) + } + // Expand projection val projection = distinctAggChildren.map { case e if group.contains(e) => e @@ -220,12 +246,17 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val operators = expressions.map { e => val af = e.aggregateFunction val naf = patchAggregateFunctionChildren(af) { x => - distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _)) + val condition = if (e.filter.isDefined) { + e.filter.map(distinctAggFilterAttrLookup.get(_)).get + } else { + None + } + distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition)) } - (e, e.copy(aggregateFunction = naf, isDistinct = false)) + (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None)) } - (projection, operators) + (projection ++ filterProjection, operators) } // Setup expand for the 'regular' aggregate expressions. @@ -253,7 +284,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Select the result of the first aggregate in the last aggregate. val result = AggregateExpression( - aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), true), + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true), mode = Complete, isDistinct = false) @@ -276,6 +307,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { Seq(a.groupingExpressions ++ distinctAggChildren.map(nullify) ++ Seq(regularGroupId) ++ + distinctAggFilters.map(nullify) ++ regularAggChildren) } else { Seq.empty[Seq[Expression]] @@ -293,7 +325,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Construct the expand operator. 
val expand = Expand( regularAggProjection ++ distinctAggProjections, - groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ distinctAggFilterAttrs ++ + regularAggChildAttrMap.map(_._2), a.child) // Construct the first aggregate operator. This de-duplicates all the children of @@ -301,7 +334,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid val firstAggregate = Aggregate( firstAggregateGroupBy, - firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), + firstAggregateGroupBy ++ boolOrs ++ regularAggOperatorMap.map(_._2), expand) // Construct the second aggregate @@ -322,56 +355,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { }.asInstanceOf[NamedExpression] } Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) - } else if (mayNeedtoProject(a)) { - var currentExprId = 0 - val (projections, filterProjections, aggPairs) = aggExpressions.map { - case ae @ AggregateExpression(af, _, isDistinct, filter, _) => - val unfoldableChildren = af.children.filter(!_.foldable) - val projectionMap = unfoldableChildren.map { - case e: Expression => - currentExprId += 1 - e -> Alias(e, s"_gen_attr_$currentExprId")() - } - val projection = projectionMap.map(_._2) - val filterProjection = filter.map { - case e: Expression => - currentExprId += 1 - e -> Alias(e, s"_gen_attr_$currentExprId")() - } - val exprAttrs = projectionMap.map { kv => - (kv._1, kv._2.toAttribute) - } - val exprAttrLookup = exprAttrs.toMap - val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c)) - val naf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] - val aggExpr = if (isDistinct && filterProjection.isDefined) { - val newFilter = Some(EqualNullSafe(filterProjection.get._2.toAttribute, Literal(true))) - ae.copy(aggregateFunction = naf, filter = newFilter) - } else { - ae.copy(aggregateFunction = naf, filter = None) - } - (projection, filterProjection.map(_._2), (ae, aggExpr)) - }.unzip3 - val namedGroupingProjection = a.groupingExpressions.flatMap { e => - e.collect { - case ar: AttributeReference => ar - } - } - val rewriteAggProjection = (namedGroupingProjection ++ projections.flatten).distinct ++ - filterProjections.flatten - val project = Project(rewriteAggProjection, a.child) - val rewriteAggExprLookup = aggPairs.toMap - val patchedAggExpressions = a.aggregateExpressions.map { e => - e.transformDown { - case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) - }.asInstanceOf[NamedExpression] - } - val aggregate = Aggregate(a.groupingExpressions, patchedAggExpressions, project) - if (aggExpressions.filter(_.isDistinct).size > 1) { - rewrite(aggregate) - } else { - aggregate - } } else { a } @@ -383,19 +366,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { }} } - private def mayNeedtoProject(a: Aggregate): Boolean = { - var flag = true - a resolveOperatorsUp { - case p: Project => - if (p.output.exists(_.name.startsWith("_gen_attr_"))) { - flag = false - } - p - case other => other - } - flag - } - private def nullify(e: Expression) = Literal.create(null, e.dataType) private def expressionAttributePair(e: Expression) = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 7ad01c3d31fb9..761ac20e84744 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -142,7 +142,6 @@ object AggUtils { val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) - val filterWithDistinctAttributes = functionsWithDistinct.flatMap(_.filterAttributes.toSeq) // 1. Create an Aggregate Operator for partial aggregations. val partialAggregate: SparkPlan = { @@ -152,12 +151,10 @@ object AggUtils { // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. createAggregate( - groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions ++ - filterWithDistinctAttributes, + groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ - filterWithDistinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = child) } @@ -169,13 +166,11 @@ object AggUtils { createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), - groupingExpressions = groupingAttributes ++ distinctAttributes ++ - filterWithDistinctAttributes, + groupingExpressions = groupingAttributes ++ distinctAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, resultExpressions = groupingAttributes ++ distinctAttributes ++ - filterWithDistinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = partialAggregate) } @@ -206,8 +201,7 @@ object AggUtils { // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. - val filter = functionsWithDistinct(i).filter - val expr = AggregateExpression(func, Partial, isDistinct = true, filter) + val expr = AggregateExpression(func, Partial, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute val attr = functionsWithDistinct(i).resultAttribute @@ -239,8 +233,7 @@ object AggUtils { // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. 
- val filter = functionsWithDistinct(i).filter - val expr = AggregateExpression(func, Final, isDistinct = true, filter) + val expr = AggregateExpression(func, Final, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute val attr = functionsWithDistinct(i).resultAttribute diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql index beb5b9e5fe516..c36fd3bee7ded 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql @@ -33,8 +33,10 @@ SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp; SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp; SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp; SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp; +SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp; -- Aggregate with filter and non-empty GroupBy expressions. SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a; @@ -44,8 +46,10 @@ SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id; SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id; SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id; -- Aggregate with filter and grouped by literals. 
SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1; @@ -58,13 +62,23 @@ select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), 
sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; +select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id; -- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE. SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1; @@ -78,9 +92,8 @@ SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1; SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1; -- Aggregate with filter, foldable input and multiple distinct groups. --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) --- FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; +SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) +FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; -- Check analysis exceptions SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql index 746b677234832..657ea59ec8f11 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql @@ -241,10 +241,9 @@ select sum(1/ten) filter (where ten > 0) from tenk1; -- select ten, sum(distinct four) filter (where four::text ~ '123') from onek a -- group by ten; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select ten, sum(distinct four) filter (where four > 10) from onek a --- group by ten --- having exists (select 1 from onek b where sum(distinct a.four) = b.four); +select ten, sum(distinct four) filter (where four > 10) from onek a +group by ten +having exists (select 1 from onek b where sum(distinct a.four) = b.four); -- [SPARK-28682] ANSI SQL: Collation Support -- select max(foo COLLATE "C") filter (where (bar collate "POSIX") > '0') diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql index fc54d179f742c..45617c53166aa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql @@ -336,9 +336,8 @@ order by 2,1; -- order by 2,1; -- FILTER queries --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select ten, sum(distinct four) filter (where 
string(four) like '123') from onek a --- group by rollup(ten); +select ten, sum(distinct four) filter (where string(four) like '123') from onek a +group by rollup(ten); -- More rescan tests -- [SPARK-27877] ANSI SQL: LATERAL derived table(T491) From 72e95f1295fbfa4d7bbba653d476891db7afd20b Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Wed, 29 Jul 2020 19:07:46 +0800 Subject: [PATCH 04/15] Supplement doc and comment. --- .../optimizer/RewriteDistinctAggregates.scala | 380 ++++++++++-------- 1 file changed, 213 insertions(+), 167 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 0e258784e5969..f530c52aaf29d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -81,10 +81,10 @@ import org.apache.spark.sql.types.IntegerType * COUNT(DISTINCT cat1) as cat1_cnt, * COUNT(DISTINCT cat2) as cat2_cnt, * SUM(value) FILTER (WHERE id > 1) AS total - * FROM - * data - * GROUP BY - * key + * FROM + * data + * GROUP BY + * key * }}} * * This translates to the following (pseudo) logical plan: @@ -93,7 +93,7 @@ import org.apache.spark.sql.types.IntegerType * key = ['key] * functions = [COUNT(DISTINCT 'cat1), * COUNT(DISTINCT 'cat2), - * sum('value) with FILTER('id > 1)] + * sum('value) FILTER (WHERE 'id > 1)] * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) * LocalTableScan [...] * }}} @@ -108,7 +108,7 @@ import org.apache.spark.sql.types.IntegerType * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) * Aggregate( * key = ['key, 'cat1, 'cat2, 'gid] - * functions = [sum('value) with FILTER('id > 1)] + * functions = [sum('value) FILTER (WHERE 'id > 1)] * output = ['key, 'cat1, 'cat2, 'gid, 'total]) * Expand( * projections = [('key, null, null, 0, cast('value as bigint), 'id), @@ -118,6 +118,49 @@ import org.apache.spark.sql.types.IntegerType * LocalTableScan [...] * }}} * + * Third example: aggregate function with distinct and filter clauses (in sql): + * {{{ + * SELECT + * COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt, + * COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_cnt, + * SUM(value) FILTER (WHERE id > 3) AS total + * FROM + * data + * GROUP BY + * key + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1) FILTER (WHERE 'id > 1), + * COUNT(DISTINCT 'cat2) FILTER (WHERE 'id > 2), + * sum('value) FILTER (WHERE 'id > 3)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] 
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ *   Aggregate(
+ *      key = ['key]
+ *      functions = [count(if (('gid = 1) and 'max_cond1) 'cat1 else null),
+ *                   count(if (('gid = 2) and 'max_cond2) 'cat2 else null),
+ *                   first(if (('gid = 0)) 'total else null) ignore nulls]
+ *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *     Aggregate(
+ *        key = ['key, 'cat1, 'cat2, 'gid]
+ *        functions = [max('cond1), max('cond2), sum('value) FILTER (WHERE 'id > 1)]
+ *        output = ['key, 'cat1, 'cat2, 'gid, 'max_cond1, 'max_cond2, 'total])
+ *       Expand(
+ *          projections = [('key, null, null, 0, null, null, cast('value as bigint), 'id),
+ *                         ('key, 'cat1, null, 1, 'id > 1, null, null, null),
+ *                         ('key, null, 'cat2, 2, null, 'id > 2, null, null)]
+ *          output = ['key, 'cat1, 'cat2, 'gid, 'cond1, 'cond2, 'value, 'id])
+ *         LocalTableScan [...]
+ * }}}
+ *
  * The rule does the following things here:
  * 1. Expand the data. There are three aggregation groups in this query:
  *    i. the non-distinct group;
  *    ii. the distinct 'cat1 group;
  *    iii. the distinct 'cat2 group;
  *    An expand operator is inserted to expand the child data for each group. The expand will null
  *    out all unused columns for the given group; this must be done in order to ensure correctness
  *    later on. Groups can be identified by a group id (gid) column added by the expand operator.
+ *    If a distinct group has a filter clause, the expand operator also evaluates the filter and
+ *    outputs its result, which is later used to compute a global condition equivalent to that
+ *    filter clause.
  * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of
  *    this aggregate consists of the original group by clause, all the requested distinct columns
  *    and the group id. Both de-duplication of distinct column and the aggregation of the
  *    non-distinct group take advantage of the fact that we group by the group id (gid) and that we
- *    have nulled out all non-relevant columns the given group.
+ *    have nulled out all non-relevant columns for the given group. If a distinct group has a
+ *    filter clause, we use MAX to aggregate the filter results produced in the previous step;
+ *    these aggregated values are equivalent to the original filter clauses.
  * 3. Aggregating the distinct groups and combining this with the results of the non-distinct
- *    aggregation. In this step we use the group id to filter the inputs for the aggregate
- *    functions. The result of the non-distinct group are 'aggregated' by using the first operator,
- *    it might be more elegant to use the native UDAF merge mechanism for this in the future.
+ *    aggregation. In this step we use the group id and the global condition to filter the inputs
+ *    for the aggregate functions. The results of the non-distinct group are 'aggregated' by using
+ *    the first operator; it might be more elegant to use the native UDAF merge mechanism for this
+ *    in the future.
  *
  * This rule duplicates the input data by two or more times (# distinct groups + an optional
 * non-distinct group). This will put quite a bit of memory pressure on the used aggregate and
@@ -149,7 +197,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     val distinctAggs = aggExpressions.filter(_.isDistinct)
     // We need at least two distinct aggregates, or a single distinct aggregate with a filter
     // clause, for this rule, because the aggregation strategy can handle a single distinct
-    // group without a filter.
+    // group without a filter clause.
// This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). distinctAggs.size > 1 || (distinctAggs.size == 1 && aggExpressions.exists(_.filter.isDefined)) } @@ -180,184 +228,182 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } } - // Aggregation strategy can handle queries with a single distinct group. - if (distinctAggGroups.size >= 1) { - // Create the attributes for the grouping id and the group by clause. - val gid = AttributeReference("gid", IntegerType, nullable = false)() - val groupByMap = a.groupingExpressions.collect { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() - } - val groupByAttrs = groupByMap.map(_._2) - - // Functions used to modify aggregate functions and their inputs. - def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) = - if (condition.isDefined) { - If(And(EqualTo(gid, id), condition.get), e, nullify(e)) - } else { - If(EqualTo(gid, id), e, nullify(e)) - } + // Aggregation strategy can handle queries with a single distinct group without filter clause. + // Create the attributes for the grouping id and the group by clause. + val gid = AttributeReference("gid", IntegerType, nullable = false)() + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() + } + val groupByAttrs = groupByMap.map(_._2) - def patchAggregateFunctionChildren( - af: AggregateFunction)( - attrs: Expression => Option[Expression]): AggregateFunction = { - val newChildren = af.children.map(c => attrs(c).getOrElse(c)) - af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + // Functions used to modify aggregate functions and their inputs. + def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) = + if (condition.isDefined) { + If(And(EqualTo(gid, id), condition.get), e, nullify(e)) + } else { + If(EqualTo(gid, id), e, nullify(e)) } - // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct - val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) - val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) - val distinctAggExprs = aggExpressions - .filter(e => e.isDistinct && e.children.exists(!_.foldable)) - val distinctAggFilterAttrMap = distinctAggExprs.collect { - case AggregateExpression(_, _, _, filter, _) if filter.isDefined => - val (e, attr) = expressionAttributePair(filter.get) - val aggregateExp = AggregateExpression(Max(attr), Partial, false) - (e, attr, Alias(aggregateExp, attr.name)()) - } - val distinctAggFilters = distinctAggFilterAttrMap.map(_._1) - val distinctAggFilterAttrs = distinctAggFilterAttrMap.map(_._2) - val boolOrs = distinctAggFilterAttrMap.map(_._3) + def patchAggregateFunctionChildren( + af: AggregateFunction)( + attrs: Expression => Option[Expression]): AggregateFunction = { + val newChildren = af.children.map(c => attrs(c).getOrElse(c)) + af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + } - // Setup expand & aggregate operators for distinct aggregate expressions. 
- val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap - val distinctAggFilterAttrLookup = distinctAggFilterAttrMap.map { tuple3 => - tuple3._1 -> tuple3._3.toAttribute - }.toMap - val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { - case ((group, expressions), i) => - val id = Literal(i + 1) + // Setup unique distinct aggregate children. + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) + val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) + // Setup all the filters in distinct aggregate. + val distinctAggExprs = aggExpressions + .filter(e => e.isDistinct && e.children.exists(!_.foldable)) + val distinctAggFilterAttrMap = distinctAggExprs.collect { + case AggregateExpression(_, _, _, filter, _) if filter.isDefined => + val (e, attr) = expressionAttributePair(filter.get) + val aggregateExp = AggregateExpression(Max(attr), Partial, false) + (e, attr, Alias(aggregateExp, attr.name)()) + } + val distinctAggFilters = distinctAggFilterAttrMap.map(_._1) + val distinctAggFilterAttrs = distinctAggFilterAttrMap.map(_._2) + val boolOrs = distinctAggFilterAttrMap.map(_._3) - val filters = expressions.filter(_.filter.isDefined).map(_.filter.get) - val filterProjection = distinctAggFilters.map { - case e if filters.contains(e) => e - case e => nullify(e) - } + // Setup expand & aggregate operators for distinct aggregate expressions. + val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap + val distinctAggFilterAttrLookup = distinctAggFilterAttrMap.map { tuple3 => + tuple3._1 -> tuple3._3.toAttribute + }.toMap + val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { + case ((group, expressions), i) => + val id = Literal(i + 1) + + // Expand projection for filter + val filters = expressions.filter(_.filter.isDefined).map(_.filter.get) + val filterProjection = distinctAggFilters.map { + case e if filters.contains(e) => e + case e => nullify(e) + } - // Expand projection - val projection = distinctAggChildren.map { - case e if group.contains(e) => e - case e => nullify(e) - } :+ id + // Expand projection + val projection = distinctAggChildren.map { + case e if group.contains(e) => e + case e => nullify(e) + } :+ id - // Final aggregate - val operators = expressions.map { e => - val af = e.aggregateFunction - val naf = patchAggregateFunctionChildren(af) { x => - val condition = if (e.filter.isDefined) { - e.filter.map(distinctAggFilterAttrLookup.get(_)).get - } else { - None - } - distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition)) + // Final aggregate + val operators = expressions.map { e => + val af = e.aggregateFunction + val naf = patchAggregateFunctionChildren(af) { x => + val condition = if (e.filter.isDefined) { + e.filter.map(distinctAggFilterAttrLookup.get(_)).get + } else { + None } - (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None)) + distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition)) } + (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None)) + } - (projection ++ filterProjection, operators) - } - - // Setup expand for the 'regular' aggregate expressions. 
- // only expand unfoldable children - val regularAggExprs = aggExpressions - .filter(e => !e.isDistinct && e.children.exists(!_.foldable)) - val regularAggFunChildren = regularAggExprs - .flatMap(_.aggregateFunction.children.filter(!_.foldable)) - val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) - val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct - val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) + (projection ++ filterProjection, operators) + } - // Setup aggregates for 'regular' aggregate expressions. - val regularGroupId = Literal(0) - val regularAggChildAttrLookup = regularAggChildAttrMap.toMap - val regularAggOperatorMap = regularAggExprs.map { e => - // Perform the actual aggregation in the initial aggregate. - val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get) - // We changed the attributes in the [[Expand]] output using expressionAttributePair. - // So we need to replace the attributes in FILTER expression with new ones. - val filterOpt = e.filter.map(_.transform { - case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) - }) - val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)() + // Setup expand for the 'regular' aggregate expressions. + // only expand unfoldable children + val regularAggExprs = aggExpressions + .filter(e => !e.isDistinct && e.children.exists(!_.foldable)) + val regularAggFunChildren = regularAggExprs + .flatMap(_.aggregateFunction.children.filter(!_.foldable)) + val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) + val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) - // Select the result of the first aggregate in the last aggregate. - val result = AggregateExpression( - aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true), - mode = Complete, - isDistinct = false) + // Setup aggregates for 'regular' aggregate expressions. + val regularGroupId = Literal(0) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap + val regularAggOperatorMap = regularAggExprs.map { e => + // Perform the actual aggregation in the initial aggregate. + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get) + // We changed the attributes in the [[Expand]] output using expressionAttributePair. + // So we need to replace the attributes in FILTER expression with new ones. + val filterOpt = e.filter.map(_.transform { + case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) + }) + val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)() - // Some aggregate functions (COUNT) have the special property that they can return a - // non-null result without any input. We need to make sure we return a result in this case. - val resultWithDefault = af.defaultResult match { - case Some(lit) => Coalesce(Seq(result, lit)) - case None => result - } + // Select the result of the first aggregate in the last aggregate. + val result = AggregateExpression( + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true), + mode = Complete, + isDistinct = false) - // Return a Tuple3 containing: - // i. The original aggregate expression (used for look ups). - // ii. The actual aggregation operator (used in the first aggregate). - // iii. 
The operator that selects and returns the result (used in the second aggregate). - (e, operator, resultWithDefault) + // Some aggregate functions (COUNT) have the special property that they can return a + // non-null result without any input. We need to make sure we return a result in this case. + val resultWithDefault = af.defaultResult match { + case Some(lit) => Coalesce(Seq(result, lit)) + case None => result } - // Construct the regular aggregate input projection only if we need one. - val regularAggProjection = if (regularAggExprs.nonEmpty) { - Seq(a.groupingExpressions ++ - distinctAggChildren.map(nullify) ++ - Seq(regularGroupId) ++ - distinctAggFilters.map(nullify) ++ - regularAggChildren) - } else { - Seq.empty[Seq[Expression]] - } + // Return a Tuple3 containing: + // i. The original aggregate expression (used for look ups). + // ii. The actual aggregation operator (used in the first aggregate). + // iii. The operator that selects and returns the result (used in the second aggregate). + (e, operator, resultWithDefault) + } - // Construct the distinct aggregate input projections. - val regularAggNulls = regularAggChildren.map(nullify) - val distinctAggProjections = distinctAggOperatorMap.map { - case (projection, _) => - a.groupingExpressions ++ - projection ++ - regularAggNulls - } + // Construct the regular aggregate input projection only if we need one. + val regularAggProjection = if (regularAggExprs.nonEmpty) { + Seq(a.groupingExpressions ++ + distinctAggChildren.map(nullify) ++ + Seq(regularGroupId) ++ + distinctAggFilters.map(nullify) ++ + regularAggChildren) + } else { + Seq.empty[Seq[Expression]] + } + + // Construct the distinct aggregate input projections. + val regularAggNulls = regularAggChildren.map(nullify) + val distinctAggProjections = distinctAggOperatorMap.map { + case (projection, _) => + a.groupingExpressions ++ + projection ++ + regularAggNulls + } - // Construct the expand operator. - val expand = Expand( - regularAggProjection ++ distinctAggProjections, - groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ distinctAggFilterAttrs ++ - regularAggChildAttrMap.map(_._2), - a.child) + // Construct the expand operator. + val expand = Expand( + regularAggProjection ++ distinctAggProjections, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ distinctAggFilterAttrs ++ + regularAggChildAttrMap.map(_._2), + a.child) - // Construct the first aggregate operator. This de-duplicates all the children of - // distinct operators, and applies the regular aggregate operators. - val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid - val firstAggregate = Aggregate( - firstAggregateGroupBy, - firstAggregateGroupBy ++ boolOrs ++ regularAggOperatorMap.map(_._2), - expand) + // Construct the first aggregate operator. This de-duplicates all the children of + // distinct operators, and applies the regular aggregate operators. 
+ val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid + val firstAggregate = Aggregate( + firstAggregateGroupBy, + firstAggregateGroupBy ++ boolOrs ++ regularAggOperatorMap.map(_._2), + expand) - // Construct the second aggregate - val transformations: Map[Expression, Expression] = - (distinctAggOperatorMap.flatMap(_._2) ++ - regularAggOperatorMap.map(e => (e._1, e._3))).toMap + // Construct the second aggregate + val transformations: Map[Expression, Expression] = + (distinctAggOperatorMap.flatMap(_._2) ++ + regularAggOperatorMap.map(e => (e._1, e._3))).toMap - val patchedAggExpressions = a.aggregateExpressions.map { e => - e.transformDown { - case e: Expression => - // The same GROUP BY clauses can have different forms (different names for instance) in - // the groupBy and aggregate expressions of an aggregate. This makes a map lookup - // tricky. So we do a linear search for a semantically equal group by expression. - groupByMap - .find(ge => e.semanticEquals(ge._1)) - .map(_._2) - .getOrElse(transformations.getOrElse(e, e)) - }.asInstanceOf[NamedExpression] - } - Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) - } else { - a + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case e: Expression => + // The same GROUP BY clauses can have different forms (different names for instance) in + // the groupBy and aggregate expressions of an aggregate. This makes a map lookup + // tricky. So we do a linear search for a semantically equal group by expression. + groupByMap + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) + }.asInstanceOf[NamedExpression] } + Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) } private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = { From 8e82e836d5164ab04f36c2deff9d8e9c39577b99 Mon Sep 17 00:00:00 2001 From: beliefer Date: Wed, 29 Jul 2020 22:28:59 +0800 Subject: [PATCH 05/15] Add test case and regenerate golden files. 
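
As a rough illustration of what the regenerated golden files now cover
(the queries below are taken verbatim from group-by-filter.sql, so no
new names are introduced here), shapes such as

    select dept_id,
           count(distinct emp_name) filter (where id > 200),
           count(distinct hiredate) filter (where hiredate > date "2003-01-01"),
           sum(salary)
    from emp
    group by dept_id;

are expected to plan and run after this series, where they previously
failed analysis with "DISTINCT and FILTER cannot be used in aggregate
functions at the same time".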
---
 .../sql-tests/results/group-by-filter.sql.out      | 316 +++++++++++++++++-
 .../postgreSQL/aggregates_part3.sql.out            |  16 +-
 .../results/postgreSQL/groupingsets.sql.out        |  21 +-
 3 files changed, 350 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
index d41d25280146b..4dac09fcefe2e 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 37
+-- Number of queries: 63


 -- !query
@@ -94,6 +94,38 @@ struct
+-- !query
+SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp
+-- !query schema
+struct<count(DISTINCT id) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd HH:mm:ss) = 2001-01-01 00:00:00)):bigint>
+-- !query output
+2
+
+
+-- !query
+SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp
+-- !query schema
+struct<count(DISTINCT id):bigint,count(DISTINCT id) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd HH:mm:ss) = 2001-01-01 00:00:00)):bigint>
+-- !query output
+8 2
+
+
+-- !query
+SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp
+-- !query schema
+struct
+-- !query output
+2 2
+
+
+-- !query
+SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp
+-- !query schema
+struct<sum(salary):double,count(DISTINCT id):bigint,count(DISTINCT id) FILTER (WHERE (hiredate = DATE '2001-01-01')):bigint>
+-- !query output
+2450.0 8 2
+
+
 -- !query
 SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a
 -- !query schema
@@ -177,6 +209,58 @@ struct
+-- !query
+SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id
+-- !query schema
+struct<dept_id:int,sum(DISTINCT salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd HH:mm:ss) > 2001-01-01 00:00:00)):double>
+-- !query output
+10 300.0
+100 400.0
+20 300.0
+30 400.0
+70 150.0
+NULL NULL
+
+
+-- !query
+SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id
+-- !query schema
+struct<dept_id:int,sum(DISTINCT salary):double,sum(DISTINCT salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd HH:mm:ss) > 2001-01-01 00:00:00)):double>
+-- !query output
+10 300.0 300.0
+100 400.0 400.0
+20 300.0 300.0
+30 400.0 400.0
+70 150.0 150.0
+NULL 400.0 NULL
+
+
+-- !query
+SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id
+-- !query schema
+struct<dept_id:int,sum(DISTINCT salary) FILTER (WHERE (hiredate > DATE '2001-01-01')):double,sum(DISTINCT salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd HH:mm:ss) > 2001-01-01 00:00:00)):double>
+-- !query output
+10 300.0 300.0
+100 400.0 400.0
+20 300.0 300.0
+30 400.0 400.0
+70 150.0 150.0
+NULL NULL NULL
+
+
+-- !query
+SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id
+-- !query schema
+struct<dept_id:int,count(id):bigint,sum(DISTINCT salary):double,sum(DISTINCT salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) > 2001-01-01)):double>
+-- !query output
+10 3 300.0 300.0
+100 2 400.0 400.0
+20 1 300.0 300.0
+30 1 400.0 400.0
+70 1 150.0 150.0
+NULL 1 400.0 NULL
+
+
 -- !query
 SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1
 -- !query schema
@@ -261,6 +345,227 @@ struct
+-- !query
+select dept_id, count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id
+-- !query schema
+struct<dept_id:int,count(DISTINCT emp_name) FILTER (WHERE (id > 200)):bigint,sum(salary):double>
+-- !query output
+10 0 400.0
+100 2 800.0
+20 1 300.0
+30 1 400.0
+70 1 150.0
+NULL 1 400.0
+
+
+-- !query
+select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id
+-- !query schema
+struct<dept_id:int,count(DISTINCT emp_name) FILTER (WHERE ((id + dept_id) > 500)):bigint,sum(salary):double>
+-- !query output
+10 0 400.0
+100 2 800.0
+20 0 300.0
+30 0 400.0
+70 1 150.0
+NULL 0 400.0
+
+
+-- !query
dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 2 0 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double> +-- !query output +10 2 0 400.0 +100 2 2 800.0 +20 1 0 300.0 +30 1 0 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 2 0 400.0 NULL +100 2 2 800.0 800.0 +20 1 1 300.0 300.0 +30 1 1 400.0 400.0 +70 1 1 150.0 150.0 +NULL 1 1 400.0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 2 0 400.0 NULL +100 2 2 800.0 800.0 +20 1 0 300.0 300.0 +30 1 0 400.0 400.0 +70 1 1 150.0 150.0 +NULL 1 0 400.0 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate):bigint,sum(salary):double> +-- !query output +10 0 2 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 0 1 400.0 +100 2 1 800.0 +20 1 0 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double> +-- !query output +10 0 1 400.0 +100 2 1 NULL +20 1 0 300.0 +30 1 1 NULL +70 1 1 150.0 +NULL 1 0 NULL + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 0 1 400.0 NULL +100 2 1 NULL 800.0 +20 1 0 300.0 300.0 +30 1 1 NULL 400.0 +70 1 1 150.0 150.0 +NULL 1 0 NULL 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT 
emp_name):bigint,sum(salary):double> +-- !query output +10 0 2 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT emp_name) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 0 1 400.0 +100 2 1 800.0 +20 1 0 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate):bigint,sum(salary):double> +-- !query output +10 NULL 2 400.0 +100 1500 2 800.0 +20 320 1 300.0 +30 430 1 400.0 +70 870 1 150.0 +NULL NULL 1 400.0 + + +-- !query +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 NULL 1 400.0 +100 1500 1 800.0 +20 320 0 300.0 +30 430 1 400.0 +70 870 1 150.0 +NULL NULL 0 400.0 + + +-- !query +select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id +-- !query schema +struct 200)):double,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double> +-- !query output +10 NULL 1 400.0 +100 750.0 1 NULL +20 320.0 0 300.0 +30 430.0 1 NULL +70 870.0 1 150.0 +NULL NULL 0 NULL + + +-- !query +select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id +-- !query schema +struct 0)):bigint,sum(salary):double> +-- !query output +10 2 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + -- !query SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema @@ -309,6 +614,15 @@ struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint> NULL 1 +-- !query +SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) +FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a +-- !query schema +struct 0)):bigint,count(DISTINCT b, c) FILTER (WHERE ((b > 0) AND (c > 2))):bigint> +-- !query output +1 1 + + -- !query SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out index 69f96b02782e3..e1f735e5fe1dc 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of 
queries: 4
+-- Number of queries: 5


 -- !query
@@ -27,6 +27,20 @@ struct 0)):d
 2828.9682539682954


+-- !query
+select ten, sum(distinct four) filter (where four > 10) from onek a
+group by ten
+having exists (select 1 from onek b where sum(distinct a.four) = b.four)
+-- !query schema
+struct 10)):bigint>
+-- !query output
+0	NULL
+2	NULL
+4	NULL
+6	NULL
+8	NULL
+
+
 -- !query
 select
   (select count(*) from (values (1)) t0(inner_c))
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out
index 7312c20876296..2619634d7d569 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 54
+-- Number of queries: 55


 -- !query
@@ -443,6 +443,25 @@ struct
 NULL	1


+-- !query
+select ten, sum(distinct four) filter (where string(four) like '123') from onek a
+group by rollup(ten)
+-- !query schema
+struct
+-- !query output
+0	NULL
+1	NULL
+2	NULL
+3	NULL
+4	NULL
+5	NULL
+6	NULL
+7	NULL
+8	NULL
+9	NULL
+NULL	NULL
+
+
 -- !query
 select count(*) from gstest4 group by rollup(unhashable_col,unsortable_col)
 -- !query schema

From 4ba808bf602d39330810ce0a3bc61fe2ba9ef2b5 Mon Sep 17 00:00:00 2001
From: beliefer
Date: Wed, 29 Jul 2020 22:34:02 +0800
Subject: [PATCH 06/15] Remove unused Project import.

---
 .../sql/catalyst/optimizer/RewriteDistinctAggregates.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index f530c52aaf29d..9dd4ef5a2050f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.IntegerType

From 145a9dd2b315f25047d7032b6027a74bef34d202 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Thu, 30 Jul 2020 10:06:55 +0800
Subject: [PATCH 07/15] Optimize code

---
 .../optimizer/RewriteDistinctAggregates.scala | 312 +++++++++---------
 1 file changed, 158 insertions(+), 154 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 9dd4ef5a2050f..55469aa9aa57c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -208,7 +208,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

   def rewrite(a: Aggregate): Aggregate = {

-    // Collect all aggregate expressions.
val aggExpressions = collectAggregateExprs(a) // Extract distinct aggregate expressions. @@ -229,184 +228,189 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group without filter clause. - // Create the attributes for the grouping id and the group by clause. - val gid = AttributeReference("gid", IntegerType, nullable = false)() - val groupByMap = a.groupingExpressions.collect { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() - } - val groupByAttrs = groupByMap.map(_._2) - - // Functions used to modify aggregate functions and their inputs. - def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) = - if (condition.isDefined) { - If(And(EqualTo(gid, id), condition.get), e, nullify(e)) - } else { - If(EqualTo(gid, id), e, nullify(e)) + if (distinctAggGroups.size > 1 || aggExpressions.exists(_.filter.isDefined)) { + // Create the attributes for the grouping id and the group by clause. + val gid = AttributeReference("gid", IntegerType, nullable = false)() + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() } + val groupByAttrs = groupByMap.map(_._2) - def patchAggregateFunctionChildren( - af: AggregateFunction)( - attrs: Expression => Option[Expression]): AggregateFunction = { - val newChildren = af.children.map(c => attrs(c).getOrElse(c)) - af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] - } + // Functions used to modify aggregate functions and their inputs. + def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) = + if (condition.isDefined) { + If(And(EqualTo(gid, id), condition.get), e, nullify(e)) + } else { + If(EqualTo(gid, id), e, nullify(e)) + } - // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct - val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) - val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) - // Setup all the filters in distinct aggregate. - val distinctAggExprs = aggExpressions - .filter(e => e.isDistinct && e.children.exists(!_.foldable)) - val distinctAggFilterAttrMap = distinctAggExprs.collect { - case AggregateExpression(_, _, _, filter, _) if filter.isDefined => - val (e, attr) = expressionAttributePair(filter.get) - val aggregateExp = AggregateExpression(Max(attr), Partial, false) - (e, attr, Alias(aggregateExp, attr.name)()) - } - val distinctAggFilters = distinctAggFilterAttrMap.map(_._1) - val distinctAggFilterAttrs = distinctAggFilterAttrMap.map(_._2) - val boolOrs = distinctAggFilterAttrMap.map(_._3) + def patchAggregateFunctionChildren( + af: AggregateFunction)( + attrs: Expression => Option[Expression]): AggregateFunction = { + val newChildren = af.children.map(c => attrs(c).getOrElse(c)) + af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + } - // Setup expand & aggregate operators for distinct aggregate expressions. - val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap - val distinctAggFilterAttrLookup = distinctAggFilterAttrMap.map { tuple3 => - tuple3._1 -> tuple3._3.toAttribute - }.toMap - val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { - case ((group, expressions), i) => - val id = Literal(i + 1) + // Setup unique distinct aggregate children. 
+ val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) + val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) + // Setup all the filters in distinct aggregate. + val distinctAggExprs = aggExpressions + .filter(e => e.isDistinct && e.children.exists(!_.foldable)) + val distinctAggFilterAttrMap = distinctAggExprs.collect { + case AggregateExpression(_, _, _, filter, _) if filter.isDefined => + val (e, attr) = expressionAttributePair(filter.get) + val aggregateExp = AggregateExpression(Max(attr), Partial, false) + (e, attr, Alias(aggregateExp, attr.name)()) + } + val distinctAggFilters = distinctAggFilterAttrMap.map(_._1) + val distinctAggFilterAttrs = distinctAggFilterAttrMap.map(_._2) + val boolOrs = distinctAggFilterAttrMap.map(_._3) - // Expand projection for filter - val filters = expressions.filter(_.filter.isDefined).map(_.filter.get) - val filterProjection = distinctAggFilters.map { - case e if filters.contains(e) => e - case e => nullify(e) - } + // Setup expand & aggregate operators for distinct aggregate expressions. + val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap + val distinctAggFilterAttrLookup = distinctAggFilterAttrMap.map { tuple3 => + tuple3._1 -> tuple3._3.toAttribute + }.toMap + val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { + case ((group, expressions), i) => + val id = Literal(i + 1) - // Expand projection - val projection = distinctAggChildren.map { - case e if group.contains(e) => e - case e => nullify(e) - } :+ id + // Expand projection for filter + val filters = expressions.filter(_.filter.isDefined).map(_.filter.get) + val filterProjection = distinctAggFilters.map { + case e if filters.contains(e) => e + case e => nullify(e) + } - // Final aggregate - val operators = expressions.map { e => - val af = e.aggregateFunction - val naf = patchAggregateFunctionChildren(af) { x => - val condition = if (e.filter.isDefined) { - e.filter.map(distinctAggFilterAttrLookup.get(_)).get - } else { - None + // Expand projection + val projection = distinctAggChildren.map { + case e if group.contains(e) => e + case e => nullify(e) + } :+ id + + // Final aggregate + val operators = expressions.map { e => + val af = e.aggregateFunction + val naf = patchAggregateFunctionChildren(af) { x => + val condition = if (e.filter.isDefined) { + e.filter.map(distinctAggFilterAttrLookup.get(_)).get + } else { + None + } + distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition)) } - distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition)) + (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None)) } - (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None)) - } - (projection ++ filterProjection, operators) - } + (projection ++ filterProjection, operators) + } - // Setup expand for the 'regular' aggregate expressions. - // only expand unfoldable children - val regularAggExprs = aggExpressions - .filter(e => !e.isDistinct && e.children.exists(!_.foldable)) - val regularAggFunChildren = regularAggExprs - .flatMap(_.aggregateFunction.children.filter(!_.foldable)) - val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) - val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct - val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) + // Setup expand for the 'regular' aggregate expressions. 
+ // only expand unfoldable children + val regularAggExprs = aggExpressions + .filter(e => !e.isDistinct && e.children.exists(!_.foldable)) + val regularAggFunChildren = regularAggExprs + .flatMap(_.aggregateFunction.children.filter(!_.foldable)) + val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) + val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) - // Setup aggregates for 'regular' aggregate expressions. - val regularGroupId = Literal(0) - val regularAggChildAttrLookup = regularAggChildAttrMap.toMap - val regularAggOperatorMap = regularAggExprs.map { e => - // Perform the actual aggregation in the initial aggregate. - val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get) - // We changed the attributes in the [[Expand]] output using expressionAttributePair. - // So we need to replace the attributes in FILTER expression with new ones. - val filterOpt = e.filter.map(_.transform { - case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) - }) - val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)() + // Setup aggregates for 'regular' aggregate expressions. + val regularGroupId = Literal(0) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap + val regularAggOperatorMap = regularAggExprs.map { e => + // Perform the actual aggregation in the initial aggregate. + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get) + // We changed the attributes in the [[Expand]] output using expressionAttributePair. + // So we need to replace the attributes in FILTER expression with new ones. + val filterOpt = e.filter.map(_.transform { + case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) + }) + val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)() - // Select the result of the first aggregate in the last aggregate. - val result = AggregateExpression( - aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true), - mode = Complete, - isDistinct = false) + // Select the result of the first aggregate in the last aggregate. + val result = AggregateExpression( + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true), + mode = Complete, + isDistinct = false) - // Some aggregate functions (COUNT) have the special property that they can return a - // non-null result without any input. We need to make sure we return a result in this case. - val resultWithDefault = af.defaultResult match { - case Some(lit) => Coalesce(Seq(result, lit)) - case None => result - } + // Some aggregate functions (COUNT) have the special property that they can return a + // non-null result without any input. We need to make sure we return a result in this case. + val resultWithDefault = af.defaultResult match { + case Some(lit) => Coalesce(Seq(result, lit)) + case None => result + } - // Return a Tuple3 containing: - // i. The original aggregate expression (used for look ups). - // ii. The actual aggregation operator (used in the first aggregate). - // iii. The operator that selects and returns the result (used in the second aggregate). - (e, operator, resultWithDefault) - } + // Return a Tuple3 containing: + // i. The original aggregate expression (used for look ups). + // ii. The actual aggregation operator (used in the first aggregate). + // iii. 
The operator that selects and returns the result (used in the second aggregate). + (e, operator, resultWithDefault) + } - // Construct the regular aggregate input projection only if we need one. - val regularAggProjection = if (regularAggExprs.nonEmpty) { - Seq(a.groupingExpressions ++ - distinctAggChildren.map(nullify) ++ - Seq(regularGroupId) ++ - distinctAggFilters.map(nullify) ++ - regularAggChildren) - } else { - Seq.empty[Seq[Expression]] - } + // Construct the regular aggregate input projection only if we need one. + val regularAggProjection = if (regularAggExprs.nonEmpty) { + Seq(a.groupingExpressions ++ + distinctAggChildren.map(nullify) ++ + Seq(regularGroupId) ++ + distinctAggFilters.map(nullify) ++ + regularAggChildren) + } else { + Seq.empty[Seq[Expression]] + } - // Construct the distinct aggregate input projections. - val regularAggNulls = regularAggChildren.map(nullify) - val distinctAggProjections = distinctAggOperatorMap.map { - case (projection, _) => - a.groupingExpressions ++ - projection ++ - regularAggNulls - } + // Construct the distinct aggregate input projections. + val regularAggNulls = regularAggChildren.map(nullify) + val distinctAggProjections = distinctAggOperatorMap.map { + case (projection, _) => + a.groupingExpressions ++ + projection ++ + regularAggNulls + } - // Construct the expand operator. - val expand = Expand( - regularAggProjection ++ distinctAggProjections, - groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ distinctAggFilterAttrs ++ - regularAggChildAttrMap.map(_._2), - a.child) + // Construct the expand operator. + val expand = Expand( + regularAggProjection ++ distinctAggProjections, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ distinctAggFilterAttrs ++ + regularAggChildAttrMap.map(_._2), + a.child) - // Construct the first aggregate operator. This de-duplicates all the children of - // distinct operators, and applies the regular aggregate operators. - val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid - val firstAggregate = Aggregate( - firstAggregateGroupBy, - firstAggregateGroupBy ++ boolOrs ++ regularAggOperatorMap.map(_._2), - expand) + // Construct the first aggregate operator. This de-duplicates all the children of + // distinct operators, and applies the regular aggregate operators. + val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid + val firstAggregate = Aggregate( + firstAggregateGroupBy, + firstAggregateGroupBy ++ boolOrs ++ regularAggOperatorMap.map(_._2), + expand) - // Construct the second aggregate - val transformations: Map[Expression, Expression] = - (distinctAggOperatorMap.flatMap(_._2) ++ - regularAggOperatorMap.map(e => (e._1, e._3))).toMap + // Construct the second aggregate + val transformations: Map[Expression, Expression] = + (distinctAggOperatorMap.flatMap(_._2) ++ + regularAggOperatorMap.map(e => (e._1, e._3))).toMap - val patchedAggExpressions = a.aggregateExpressions.map { e => - e.transformDown { - case e: Expression => - // The same GROUP BY clauses can have different forms (different names for instance) in - // the groupBy and aggregate expressions of an aggregate. This makes a map lookup - // tricky. So we do a linear search for a semantically equal group by expression. 
- groupByMap - .find(ge => e.semanticEquals(ge._1)) - .map(_._2) - .getOrElse(transformations.getOrElse(e, e)) - }.asInstanceOf[NamedExpression] + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case e: Expression => + // The same GROUP BY clauses can have different forms (different names for instance) in + // the groupBy and aggregate expressions of an aggregate. This makes a map lookup + // tricky. So we do a linear search for a semantically equal group by expression. + groupByMap + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) + }.asInstanceOf[NamedExpression] + } + Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) + } else { + a } - Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) } private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = { + // Collect all aggregate expressions. a.aggregateExpressions.flatMap { _.collect { case ae: AggregateExpression => ae }} From 0fcf643ff03492296a773e98edacf3e151fe34fc Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Thu, 30 Jul 2020 17:22:17 +0800 Subject: [PATCH 08/15] Update doc --- .../sql/catalyst/optimizer/RewriteDistinctAggregates.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 55469aa9aa57c..25282a2e9a183 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -151,7 +151,7 @@ import org.apache.spark.sql.types.IntegerType * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) * Aggregate( * key = ['key, 'cat1, 'cat2, 'gid] - * functions = [max('cond1), max('cond2), sum('value) FILTER (WHERE 'id > 1)] + * functions = [max('cond1), max('cond2), sum('value) FILTER (WHERE 'id > 3)] * output = ['key, 'cat1, 'cat2, 'gid, 'max_cond1, 'max_cond2, 'total]) * Expand( * projections = [('key, null, null, 0, null, null, cast('value as bigint), 'id), From 92a37a96ce5e4e43fb96c7faa1b34cb32e896955 Mon Sep 17 00:00:00 2001 From: beliefer Date: Thu, 30 Jul 2020 22:35:21 +0800 Subject: [PATCH 09/15] Optimize code. --- .../optimizer/RewriteDistinctAggregates.scala | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 25282a2e9a183..5ac5027cfe9e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -199,7 +199,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // clause for this rule because aggregation strategy can handle a single distinct aggregate // group without filter clause. // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). 
- distinctAggs.size > 1 || (distinctAggs.size == 1 && aggExpressions.exists(_.filter.isDefined)) + distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -209,6 +209,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def rewrite(a: Aggregate): Aggregate = { val aggExpressions = collectAggregateExprs(a) + val distinctAggs = aggExpressions.filter(_.isDistinct) // Extract distinct aggregate expressions. val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => @@ -228,7 +229,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group without filter clause. - if (distinctAggGroups.size > 1 || aggExpressions.exists(_.filter.isDefined)) { + if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -259,21 +260,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Setup all the filters in distinct aggregate. val distinctAggExprs = aggExpressions .filter(e => e.isDistinct && e.children.exists(!_.foldable)) - val distinctAggFilterAttrMap = distinctAggExprs.collect { + val (distinctAggFilters, distinctAggFilterAttrs, maxCond) = distinctAggExprs.collect { case AggregateExpression(_, _, _, filter, _) if filter.isDefined => val (e, attr) = expressionAttributePair(filter.get) val aggregateExp = AggregateExpression(Max(attr), Partial, false) (e, attr, Alias(aggregateExp, attr.name)()) - } - val distinctAggFilters = distinctAggFilterAttrMap.map(_._1) - val distinctAggFilterAttrs = distinctAggFilterAttrMap.map(_._2) - val boolOrs = distinctAggFilterAttrMap.map(_._3) + }.unzip3 // Setup expand & aggregate operators for distinct aggregate expressions. val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap - val distinctAggFilterAttrLookup = distinctAggFilterAttrMap.map { tuple3 => - tuple3._1 -> tuple3._3.toAttribute - }.toMap + val distinctAggFilterAttrLookup = distinctAggFilters.zip(maxCond.map(_.toAttribute)).toMap val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { case ((group, expressions), i) => val id = Literal(i + 1) @@ -383,7 +379,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid val firstAggregate = Aggregate( firstAggregateGroupBy, - firstAggregateGroupBy ++ boolOrs ++ regularAggOperatorMap.map(_._2), + firstAggregateGroupBy ++ maxCond ++ regularAggOperatorMap.map(_._2), expand) // Construct the second aggregate From 7362dfbdc165f0ff24e24ac02074601852447cf8 Mon Sep 17 00:00:00 2001 From: beliefer Date: Thu, 30 Jul 2020 22:38:23 +0800 Subject: [PATCH 10/15] Optimize code. 
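
This renames maxCond to maxConds: the collect/unzip3 introduced in the
previous commit yields one Max(attr) aggregate per filtered distinct
aggregate expression, so the plural name makes it clearer that the value is
a sequence. For a query with two filtered distinct aggregates, such as the
following one from the golden files, the first aggregate should produce two
such max conditions, one per FILTER clause:

    select dept_id,
           count(distinct emp_name) filter (where id > 200),
           count(distinct hiredate) filter (where hiredate > date "2003-01-01"),
           sum(salary)
    from emp group by dept_id;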
---
 .../sql/catalyst/optimizer/RewriteDistinctAggregates.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 5ac5027cfe9e0..0c9e76ee9967e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -260,7 +260,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       // Setup all the filters in distinct aggregate.
       val distinctAggExprs = aggExpressions
         .filter(e => e.isDistinct && e.children.exists(!_.foldable))
-      val (distinctAggFilters, distinctAggFilterAttrs, maxCond) = distinctAggExprs.collect {
+      val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggExprs.collect {
         case AggregateExpression(_, _, _, filter, _) if filter.isDefined =>
           val (e, attr) = expressionAttributePair(filter.get)
           val aggregateExp = AggregateExpression(Max(attr), Partial, false)
@@ -269,7 +269,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

       // Setup expand & aggregate operators for distinct aggregate expressions.
       val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
-      val distinctAggFilterAttrLookup = distinctAggFilters.zip(maxCond.map(_.toAttribute)).toMap
+      val distinctAggFilterAttrLookup = distinctAggFilters.zip(maxConds.map(_.toAttribute)).toMap
       val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
         case ((group, expressions), i) =>
           val id = Literal(i + 1)
@@ -379,7 +379,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
       val firstAggregate = Aggregate(
         firstAggregateGroupBy,
-        firstAggregateGroupBy ++ maxCond ++ regularAggOperatorMap.map(_._2),
+        firstAggregateGroupBy ++ maxConds ++ regularAggOperatorMap.map(_._2),
         expand)

       // Construct the second aggregate

From 9939ea7a852685a3d31136e30e2a9e44fbe7fb5f Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Fri, 31 Jul 2020 14:50:13 +0800
Subject: [PATCH 11/15] Add test cases like distinct 1

---
 .../sql-tests/inputs/group-by-filter.sql      |  3 ++
 .../sql-tests/results/group-by-filter.sql.out | 33 ++++++++++++++++++-
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
index 0f660d8998627..35762cf7ce5c3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
@@ -40,6 +40,7 @@ SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:
 SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp;
 SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp;
 SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp;
+SELECT COUNT(DISTINCT 1) FILTER (WHERE a = 1) FROM testData;

 -- Aggregate with filter and non-empty GroupBy expressions.
SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a; @@ -53,6 +54,7 @@ SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-M SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id; +SELECT b, COUNT(DISTINCT 1) FILTER (WHERE a = 1) FROM testData GROUP BY b; -- Aggregate with filter and grouped by literals. SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1; @@ -82,6 +84,7 @@ select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(dist select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id; select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id; +select dept_id, count(distinct 1), count(distinct 1) filter (where id > 200), sum(salary) from emp group by dept_id; -- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE. SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index 4dac09fcefe2e..ba7792943ccc3 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 63 +-- Number of queries: 66 -- !query @@ -126,6 +126,14 @@ struct +-- !query output +1 + + -- !query SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a -- !query schema @@ -261,6 +269,16 @@ struct +-- !query output +1 1 +2 1 +NULL 0 + + -- !query SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1 -- !query schema @@ -566,6 +584,19 @@ struct 0)):bi NULL 1 400.0 +-- !query +select dept_id, count(distinct 1), count(distinct 1) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 1 0 400.0 +100 1 1 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + -- !query SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema From 2dc6f32f3a065702fb5c974e4051f23675e32196 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 31 Jul 2020 19:07:16 +0800 Subject: [PATCH 12/15] Optimize code --- .../optimizer/RewriteDistinctAggregates.scala | 14 +++++++++----- .../sql-tests/inputs/group-by-filter.sql | 2 ++ .../sql-tests/results/group-by-filter.sql.out | 16 ++++++++++++++++ 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 0c9e76ee9967e..5d08d6b56da61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -170,17 +170,21 @@ import org.apache.spark.sql.types.IntegerType
 * out all unused columns for the given group; this must be done in order to ensure correctness
 * later on. Groups can by identified by a group id (gid) column added by the expand operator.
 * If distinct group exists filter clause, the expand will calculate the filter and output it's
- * result which will be used to calculate the global conditions equivalent to filter clauses.
+ * result (e.g. cond1) which will be used to calculate the global conditions (e.g. max_cond1)
+ * equivalent to filter clauses.
 * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of
 * this aggregate consists of the original group by clause, all the requested distinct columns
 * and the group id. Both de-duplication of distinct column and the aggregation of the
 * non-distinct group take advantage of the fact that we group by the group id (gid) and that we
 * have nulled out all non-relevant columns the given group. If distinct group exists filter
- * clause, we will use max to aggregate the results of the filter output in the previous step.
- * These aggregate values are equivalent to filter clauses.
+ * clause, we will use max to aggregate the results (e.g. cond1) of the filter output in the
+ * previous step. These aggregates output the global conditions (e.g. max_cond1) equivalent
+ * to filter clauses.
 * 3. Aggregating the distinct groups and combining this with the results of the non-distinct
 * aggregation. In this step we use the group id and the global condition to filter the inputs
- * for the aggregate functions. The result of the non-distinct group are 'aggregated' by using
+ * for the aggregate functions. If the global condition (e.g. max_cond1) is true, it means at
+ * least one row with that distinct value satisfies the filter. This distinct value should be included
+ * in the aggregate function. The results of the non-distinct group are 'aggregated' by using
 * the first operator, it might be more elegant to use the native UDAF merge mechanism for this
 * in the future.
 *
@@ -263,7 +267,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggExprs.collect {
         case AggregateExpression(_, _, _, filter, _) if filter.isDefined =>
           val (e, attr) = expressionAttributePair(filter.get)
-          val aggregateExp = AggregateExpression(Max(attr), Partial, false)
+          val aggregateExp = Max(attr).toAggregateExpression()
           (e, attr, Alias(aggregateExp, attr.name)())
       }.unzip3
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
index 35762cf7ce5c3..24d303621faea 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
@@ -41,6 +41,8 @@ SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate
 SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp;
 SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp;
 SELECT COUNT(DISTINCT 1) FILTER (WHERE a = 1) FROM testData;
+SELECT COUNT(DISTINCT id) FILTER (WHERE true) FROM emp;
+SELECT COUNT(DISTINCT id) FILTER (WHERE false) FROM emp;

 -- Aggregate with filter and non-empty GroupBy expressions.
 SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
index ba7792943ccc3..669da62ae3120 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
@@ -134,6 +134,22 @@ struct
 1


+-- !query
+SELECT COUNT(DISTINCT id) FILTER (WHERE true) FROM emp
+-- !query schema
+struct
+-- !query output
+8
+
+
+-- !query
+SELECT COUNT(DISTINCT id) FILTER (WHERE false) FROM emp
+-- !query schema
+struct
+-- !query output
+0
+
+
 -- !query
 SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a
 -- !query schema

From abafc208acb8fa6f2987a2ee00b1761a26086266 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Fri, 31 Jul 2020 19:09:17 +0800
Subject: [PATCH 13/15] Regenerate golden files

---
 .../test/resources/sql-tests/results/group-by-filter.sql.out | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
index 669da62ae3120..c349d9d84c226 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 66
+-- Number of queries: 68


 -- !query

From 39583dde43da9580245cd34768d3f613fab8b090 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Mon, 3 Aug 2020 13:16:40 +0800
Subject: [PATCH 14/15] Optimize code

---
 .../sql/catalyst/optimizer/RewriteDistinctAggregates.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 5d08d6b56da61..c40668177e843 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -262,8 +262,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) // Setup all the filters in distinct aggregate. - val distinctAggExprs = aggExpressions - .filter(e => e.isDistinct && e.children.exists(!_.foldable)) + val distinctAggExprs = aggExpressions.filter(e => e.isDistinct) val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggExprs.collect { case AggregateExpression(_, _, _, filter, _) if filter.isDefined => val (e, attr) = expressionAttributePair(filter.get) From 883973b9bc8a9c530a002cf4b48217546929fb5e Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Mon, 3 Aug 2020 13:40:19 +0800 Subject: [PATCH 15/15] Optimize code --- .../sql/catalyst/optimizer/RewriteDistinctAggregates.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index c40668177e843..af3a8fe684bb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -262,8 +262,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) // Setup all the filters in distinct aggregate. - val distinctAggExprs = aggExpressions.filter(e => e.isDistinct) - val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggExprs.collect { + val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggs.collect { case AggregateExpression(_, _, _, filter, _) if filter.isDefined => val (e, attr) = expressionAttributePair(filter.get) val aggregateExp = Max(attr).toAggregateExpression()