From 4a6f903897d28a3038918997e692410259a90ae3 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 19 Jun 2020 10:36:52 +0800 Subject: [PATCH 01/11] Reuse completeNextStageWithFetchFailure --- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9d412f2dba3ce..762b14e170fcc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1796,9 +1796,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // lets say there is a fetch failure in this task set, which makes us go back and // run stage 0, attempt 1 - complete(taskSets(1), Seq( - (FetchFailed(makeBlockManagerId("hostA"), - shuffleDep1.shuffleId, 0L, 0, 0, "ignored"), null))) + completeNextStageWithFetchFailure(1, 0, shuffleDep1) scheduler.resubmitFailedStages() // stage 0, attempt 1 should have the properties of job2 @@ -1872,9 +1870,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have the second stage complete normally completeShuffleMapStageSuccessfully(1, 0, 1, Seq("hostA", "hostC")) // fail the third stage because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), - shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null))) + completeNextStageWithFetchFailure(2, 0, shuffleDepTwo) // TODO assert this: // blockManagerMaster.removeExecutor("hostA-exec") // have DAGScheduler try again @@ -1900,9 +1896,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // complete stage 1 completeShuffleMapStageSuccessfully(1, 0, 1) // pretend stage 2 failed because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), - shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null))) + completeNextStageWithFetchFailure(2, 0, shuffleDepTwo) // TODO assert this: // blockManagerMaster.removeExecutor("hostA-exec") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. 
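For reference, the helper now reused in all three places presumably has the following shape (a sketch inferred from the removed call sites; the real definition already lives in DAGSchedulerSuite, so the exact signature and body here are assumptions):

// Sketch only: completes the given task set with a single FetchFailed from
// "hostA" on the given shuffle dependency, mirroring the inlined completions
// that this patch removes. `attemptIdx` is presumably validated against the
// task set's stage attempt; that check is omitted here.
private def completeNextStageWithFetchFailure(
    stageId: Int,
    attemptIdx: Int,
    shuffleDep: ShuffleDependency[_, _, _]): Unit = {
  complete(taskSets(stageId), Seq(
    (FetchFailed(makeBlockManagerId("hostA"),
      shuffleDep.shuffleId, 0L, 0, 0, "ignored"), null)))
}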
From 5427485025a3ed251f245387521311b33724da3c Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Thu, 16 Jul 2020 18:16:25 +0800 Subject: [PATCH 02/11] add new rule to project filter --- .../sql/catalyst/analysis/Analyzer.scala | 12 +- .../spark/sql/catalyst/dsl/package.scala | 2 + .../expressions/aggregate/interfaces.scala | 18 +- .../sql/catalyst/optimizer/Optimizer.scala | 2 + .../optimizer/ProjectFilterInAggregates.scala | 178 ++++++++++++++++++ .../optimizer/RewriteDistinctAggregates.scala | 5 +- .../analysis/AnalysisErrorSuite.scala | 5 - .../ProjectFilterInAggregatesSuite.scala | 59 ++++++ .../sql/execution/aggregate/AggUtils.scala | 19 +- .../sql-tests/inputs/group-by-filter.sql | 41 ++-- .../inputs/postgreSQL/aggregates_part3.sql | 7 +- .../inputs/postgreSQL/groupingsets.sql | 5 +- 12 files changed, 306 insertions(+), 47 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregatesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 023ef2ee17473..f0d75be54adb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1922,15 +1922,9 @@ class Analyzer( } // We get an aggregate function, we need to wrap it in an AggregateExpression. case agg: AggregateFunction => - // TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT - if (filter.isDefined) { - if (isDistinct) { - failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " + - "at the same time") - } else if (!filter.get.deterministic) { - failAnalysis("FILTER expression is non-deterministic, " + - "it cannot be used in aggregate functions") - } + if (filter.isDefined && !filter.get.deterministic) { + failAnalysis("FILTER expression is non-deterministic, " + + "it cannot be used in aggregate functions") } AggregateExpression(agg, Complete, isDistinct, filter) // This function is not an aggregate function, just return the resolved one. 
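With the failAnalysis branch removed, DISTINCT and FILTER can now be combined in one aggregate function; only a non-deterministic FILTER predicate is still rejected. A minimal illustration (assuming a SparkSession named `spark` and a table `t(a INT, c INT)`, both hypothetical):

// Accepted after this change; previously failed with "DISTINCT and FILTER
// cannot be used in aggregate functions at the same time".
spark.sql("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM t")

// Still rejected: the FILTER predicate is non-deterministic.
spark.sql("SELECT count(a) FILTER (WHERE rand() > 0.5) FROM t")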
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 26f5bee72092c..3bdd9122e5da1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -173,6 +173,8 @@ package object dsl {
     def count(e: Expression): Expression = Count(e).toAggregateExpression()
     def countDistinct(e: Expression*): Expression =
       Count(e).toAggregateExpression(isDistinct = true)
+    def countDistinct(filter: Option[Expression], e: Expression*): Expression =
+      Count(e).toAggregateExpression(isDistinct = true, filter = filter)
     def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
       HyperLogLogPlusPlus(e, rsd).toAggregateExpression()
     def avg(e: Expression): Expression = Average(e).toAggregateExpression()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 8e8862edb6dd5..0e46dcc0ee3d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -216,15 +216,21 @@ abstract class AggregateFunction extends Expression {
   def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false)
 
   /**
-   * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct`
-   * flag of the [[AggregateExpression]] to the given value because
+   * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets the `isDistinct`
+   * flag and the optional `filter` of the [[AggregateExpression]] to the given values, because
    * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode,
-   * and the flag indicating if this aggregation is distinct aggregation or not.
-   * An [[AggregateFunction]] should not be used without being wrapped in
+   * the flag indicating whether this aggregation is distinct, and the optional `filter`.
+   * An [[AggregateFunction]] should not be used without being wrapped in
    * an [[AggregateExpression]].
   */
-  def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
-    AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
+  def toAggregateExpression(
+      isDistinct: Boolean,
+      filter: Option[Expression] = None): AggregateExpression = {
+    AggregateExpression(
+      aggregateFunction = this,
+      mode = Complete,
+      isDistinct = isDistinct,
+      filter = filter)
   }
 
   def sql(isDistinct: Boolean): String = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 33da482c4eea4..4cb010f876154 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -140,6 +140,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteNonCorrelatedExists,
       ComputeCurrentTime,
       GetCurrentDatabaseAndCatalog(catalogManager),
+      ProjectFilterInAggregates,
       RewriteDistinctAggregates,
       ReplaceDeduplicateWithAggregate) ::
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -237,6 +238,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     ReplaceExpressions.ruleName ::
       ComputeCurrentTime.ruleName ::
       GetCurrentDatabaseAndCatalog(catalogManager).ruleName ::
+      ProjectFilterInAggregates.ruleName ::
      RewriteDistinctAggregates.ruleName ::
      ReplaceDeduplicateWithAggregate.ruleName ::
      ReplaceIntersectWithSemiJoin.ruleName ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala
new file mode 100644
index 0000000000000..825ea396d4cc2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, If, IsNotNull, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * For an aggregate query with filter clauses, this rule creates a Project node that evaluates
+ * the filter conditions on the output of the aggregate's child in advance.
+ *
+ * Example: a query with filter clauses (in SQL):
+ * {{{
+ *   val data = Seq(
+ *     (1, "a", "ca1", "cb1", 10),
+ *     (2, "a", "ca1", "cb2", 5),
+ *     (3, "b", "ca1", "cb1", 13))
+ *     .toDF("id", "key", "cat1", "cat2", "value")
+ *   data.createOrReplaceTempView("data")
+ *
+ *   SELECT
+ *     COUNT(DISTINCT cat1) AS cat1_cnt,
+ *     COUNT(DISTINCT cat2) FILTER (WHERE id > 1) AS cat2_cnt,
+ *     SUM(value) AS total,
+ *     SUM(value) FILTER (WHERE key = "a") AS total2
+ *   FROM
+ *     data
+ *   GROUP BY
+ *     key
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [COUNT(DISTINCT 'cat1),
+ *                 COUNT(DISTINCT 'cat2) with FILTER('id > 1),
+ *                 SUM('value),
+ *                 SUM('value) with FILTER('key = "a")]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total, 'total2])
+ *   LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [count('_gen_attr_1),
+ *                 count('_gen_attr_2) with FILTER('_gen_attr_2 is not null),
+ *                 sum('_gen_attr_3),
+ *                 sum('_gen_attr_4) with FILTER('_gen_attr_4 is not null)]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total, 'total2])
+ *   Project(
+ *      projectionList = ['key,
+ *                        'cat1,
+ *                        if ('id > 1) 'cat2 else null,
+ *                        cast('value as bigint),
+ *                        if ('key = "a") cast('value as bigint) else null]
+ *      output = ['key, '_gen_attr_1, '_gen_attr_2, '_gen_attr_3, '_gen_attr_4])
+ *     LocalTableScan [...]
+ * }}}
+ *
+ * The rule does the following things here:
+ * 1. Project the output of the child of the aggregate query. There are two aggregation
+ *    groups in this query:
+ *    i. the group without a filter clause;
+ *    ii. the group with a filter clause.
+ *    When at least one aggregate function has a filter clause, we add a Project node
+ *    on the input plan.
+ * 2. Avoid projections that would output the same attribute more than once. There are three
+ *    aggregation groups in this query:
+ *    i. the non-distinct 'cat1 group without a filter clause;
+ *    ii. the distinct 'cat1 group without a filter clause;
+ *    iii. the distinct 'cat1 group with a filter clause.
+ *    The attributes referenced by different aggregate expressions are likely to overlap, and
+ *    without additional processing output would be lost: if we directly output the attributes
+ *    of the aggregate expressions, we may get three 'cat1 attributes. To prevent this, we
+ *    generate new attributes (e.g. '_gen_attr_1) and replace the original ones.
+ *
+ * Why do we need the Project phase? It guarantees that the filter clauses are computed
+ * locally, before the first aggregate.
+ * Note: after generating the new attributes, the aggregate may contain at least two distinct
+ * groups, and so may trigger [[RewriteDistinctAggregates]].
+ */
+object ProjectFilterInAggregates extends Rule[LogicalPlan] {
+
+  private def collectAggregateExprs(exprs: Seq[Expression]): Seq[AggregateExpression] = {
+    exprs.flatMap { _.collect {
+      case ae: AggregateExpression => ae
+    }}
+  }
+
+  private def mayNeedtoProject(exprs: Seq[Expression]): Boolean = {
+    collectAggregateExprs(exprs).exists(_.filter.isDefined)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case a: Aggregate if mayNeedtoProject(a.aggregateExpressions) => project(a)
+  }
+
+  def project(a: Aggregate): Aggregate = {
+    val aggExpressions = collectAggregateExprs(a.aggregateExpressions)
+    // Constructs pairs between old and new expressions for aggregates.
+    val aggExprs = aggExpressions.filter(e => e.children.exists(!_.foldable))
+    val (projections, aggPairs) = aggExprs.map {
+      case ae @ AggregateExpression(af, _, _, filter, _) =>
+        // First, in order to reduce costs, it is better to handle the filter clause locally:
+        // e.g. for COUNT(DISTINCT a) FILTER (WHERE id > 1), evaluate the expression
+        // If('id > 1, 'a, null) first, and use its result as the aggregate input.
+        // Second, if at least two DISTINCT aggregate expressions reference the same
+        // attributes, we need to generate fresh attributes so that no output is lost:
+        // e.g. SUM(DISTINCT a), COUNT(DISTINCT a) FILTER (WHERE id > 1) will output
+        // attributes '_gen_attr_1 and '_gen_attr_2 instead of two 'a.
+        // Note: this aliasing may result in at least two distinct groups, so
+        // RewriteDistinctAggregates may rewrite the logical plan.
+        val unfoldableChildren = af.children.filter(!_.foldable)
+        // Expand projection
+        val projectionMap = unfoldableChildren.map {
+          case e if filter.isDefined =>
+            val ife = If(filter.get, e, Literal.create(null, e.dataType))
+            e -> Alias(ife, s"_gen_attr_${NamedExpression.newExprId.id}")()
+          // For convenience and unification, we always alias the column, even if
+          // there is no filter.
+          case e => e -> Alias(e, s"_gen_attr_${NamedExpression.newExprId.id}")()
+        }
+        val projection = projectionMap.map(_._2)
+        val exprAttrs = projectionMap.map { kv =>
+          (kv._1, kv._2.toAttribute)
+        }
+        val exprAttrLookup = exprAttrs.toMap
+        val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c))
+        val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
+        val aggExpr = if (filter.isDefined) {
+          // When the filter evaluates to false, the conditional expression outputs null,
+          // which would affect the results of aggregate functions that do not ignore
+          // nulls (e.g. count). So we add a new filter with IsNotNull.
+          ae.copy(aggregateFunction = raf, filter = Some(IsNotNull(newChildren.last)))
+        } else {
+          ae.copy(aggregateFunction = raf, filter = None)
+        }
+
+        (projection, (ae, aggExpr))
+    }.unzip
+    // Construct the aggregate input projection.
+    val namedGroupingProjection = a.groupingExpressions.flatMap { e =>
+      e.collect {
+        case ar: AttributeReference => ar
+      }
+    }
+    val rewriteAggProjection = namedGroupingProjection ++ projections.flatten
+    // Construct the project operator.
+ val project = Project(rewriteAggProjection, a.child) + val rewriteAggExprLookup = aggPairs.toMap + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) + }.asInstanceOf[NamedExpression] + } + Aggregate(a.groupingExpressions, patchedAggExpressions, project) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index e5571069a7c41..ad41eb8f4b729 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -226,7 +226,10 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val naf = patchAggregateFunctionChildren(af) { x => distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _)) } - (e, e.copy(aggregateFunction = naf, isDistinct = false)) + val filterOpt = e.filter.map(_.transform { + case a: Attribute => distinctAggChildAttrLookup.getOrElse(a, a) + }) + (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = filterOpt)) } (projection, operators) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 166ffec44a60d..a99f7e2be6e7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -207,11 +207,6 @@ class AnalysisErrorSuite extends AnalysisTest { "FILTER (WHERE c > 1)"), "FILTER predicate specified, but aggregate is not an aggregate function" :: Nil) - errorTest( - "DISTINCT aggregate function with filter predicate", - CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"), - "DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil) - errorTest( "non-deterministic filter predicate in aggregate functions", CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregatesSuite.scala new file mode 100644 index 0000000000000..4bf5b39d6d180 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregatesSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+class ProjectFilterInAggregatesSuite extends PlanTest {
+  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
+  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+  val analyzer = new Analyzer(catalog, conf)
+
+  val nullInt = Literal(null, IntegerType)
+  val nullString = Literal(null, StringType)
+  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)
+
+  // Checks that the rule produced an Aggregate whose direct child is the new Project.
+  private def checkGenerate(generate: LogicalPlan): Unit = generate match {
+    case Aggregate(_, _, _: Project) =>
+    case _ => fail(s"Expected an Aggregate over a Project, but got:\n$generate")
+  }
+
+  test("single distinct group with filter") {
+    val input = testRelation
+      .groupBy('a)(countDistinct(Some(EqualTo('d, Literal(""))), 'e))
+      .analyze
+    checkGenerate(ProjectFilterInAggregates(input))
+  }
+
+  test("at least one distinct group with filter") {
+    val input = testRelation
+      .groupBy('a)(countDistinct(Some(EqualTo('d, Literal(""))), 'e), countDistinct('d))
+      .analyze
+    checkGenerate(ProjectFilterInAggregates(input))
+  }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 761ac20e84744..40c022932886d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -142,19 +142,25 @@ object AggUtils {
     val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute)
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
+    val distinctFilterAttributes = functionsWithDistinct.flatMap(_.filterAttributes.toSeq)
 
     // 1. Create an Aggregate Operator for partial aggregations.
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
       // We will group by the original grouping expression, plus an additional expression for the
-      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
-      // expressions will be [key, value].
+      // DISTINCT column and the attributes referenced by the FILTER clause associated with each
+      // aggregate function. For example,
+      // 1. For AVG(DISTINCT value) GROUP BY key, the grouping expressions will be [key, value].
+      // 2. For AVG(DISTINCT value) FILTER (WHERE age > 20) GROUP BY key, the grouping expression
+      // will be [key, value, age].
createAggregate( - groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions, + groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions ++ + distinctFilterAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ + distinctFilterAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = child) } @@ -166,11 +172,13 @@ object AggUtils { createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), - groupingExpressions = groupingAttributes ++ distinctAttributes, + groupingExpressions = groupingAttributes ++ distinctAttributes ++ + distinctFilterAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, resultExpressions = groupingAttributes ++ distinctAttributes ++ + distinctFilterAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = partialAggregate) } @@ -201,7 +209,8 @@ object AggUtils { // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. - val expr = AggregateExpression(func, Partial, isDistinct = true) + val filter = functionsWithDistinct(i).filter + val expr = AggregateExpression(func, Partial, isDistinct = true, filter) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute val attr = functionsWithDistinct(i).resultAttribute diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql index beb5b9e5fe516..c36fd3bee7ded 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql @@ -33,8 +33,10 @@ SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp; SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp; SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp; SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp; +SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp; -- Aggregate with filter and non-empty GroupBy expressions. 
SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a; @@ -44,8 +46,10 @@ SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id; SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id; SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id; -- Aggregate with filter and grouped by literals. SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1; @@ -58,13 +62,23 @@ select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group 
by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; +select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id; -- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE. SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1; @@ -78,9 +92,8 @@ SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1; SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1; -- Aggregate with filter, foldable input and multiple distinct groups. 
--- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) --- FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; +SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) +FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; -- Check analysis exceptions SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql index 746b677234832..657ea59ec8f11 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql @@ -241,10 +241,9 @@ select sum(1/ten) filter (where ten > 0) from tenk1; -- select ten, sum(distinct four) filter (where four::text ~ '123') from onek a -- group by ten; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select ten, sum(distinct four) filter (where four > 10) from onek a --- group by ten --- having exists (select 1 from onek b where sum(distinct a.four) = b.four); +select ten, sum(distinct four) filter (where four > 10) from onek a +group by ten +having exists (select 1 from onek b where sum(distinct a.four) = b.four); -- [SPARK-28682] ANSI SQL: Collation Support -- select max(foo COLLATE "C") filter (where (bar collate "POSIX") > '0') diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql index fc54d179f742c..45617c53166aa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql @@ -336,9 +336,8 @@ order by 2,1; -- order by 2,1; -- FILTER queries --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select ten, sum(distinct four) filter (where string(four) like '123') from onek a --- group by rollup(ten); +select ten, sum(distinct four) filter (where string(four) like '123') from onek a +group by rollup(ten); -- More rescan tests -- [SPARK-27877] ANSI SQL: LATERAL derived table(T491) From 4f9c7e6b58af4ead9a8b188476a573b769500482 Mon Sep 17 00:00:00 2001 From: beliefer Date: Thu, 16 Jul 2020 23:17:59 +0800 Subject: [PATCH 03/11] idempotence and regenerate golden files. 
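Without a guard, a second run of the optimizer batch would stack another Project of _gen_attr_* aliases on top of the first one. The idempotence can be checked roughly as follows (hypothetical test code, reusing `testRelation`, the expression DSL, and `comparePlans` from the suite above):

// Applying the rule twice must yield the same plan as applying it once; the
// guard in this patch skips aggregates whose input already carries
// _gen_attr_* aliases.
val input = testRelation
  .groupBy('a)(countDistinct(Some(EqualTo('d, Literal(""))), 'e))
  .analyze
val once = ProjectFilterInAggregates(input)
comparePlans(once, ProjectFilterInAggregates(once))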
--- .../optimizer/ProjectFilterInAggregates.scala | 19 +- .../sql-tests/results/explain-aqe.sql.out | 31 +- .../sql-tests/results/explain.sql.out | 29 +- .../sql-tests/results/group-by-filter.sql.out | 316 +++++++++++++++++- .../postgreSQL/aggregates_part3.sql.out | 16 +- .../results/postgreSQL/groupingsets.sql.out | 21 +- 6 files changed, 401 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala index 825ea396d4cc2..3ec6706c52a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala @@ -107,12 +107,25 @@ object ProjectFilterInAggregates extends Rule[LogicalPlan] { }} } - private def mayNeedtoProject(exprs: Seq[Expression]): Boolean = { - collectAggregateExprs(exprs).exists(_.filter.isDefined) + private def mayNeedtoProject(a: Aggregate): Boolean = { + if (collectAggregateExprs(a.aggregateExpressions).exists(_.filter.isDefined)) { + var flag = true + a resolveOperatorsUp { + case p: Project => + if (p.output.exists(_.name.startsWith("_gen_attr_"))) { + flag = false + } + p + case other => other + } + flag + } else { + false + } } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case a: Aggregate if mayNeedtoProject(a.aggregateExpressions) => project(a) + case a: Aggregate if mayNeedtoProject(a) => project(a) } def project(a: Aggregate): Aggregate = { diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 36757863ffcb5..132c08626faec 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -692,11 +692,12 @@ EXPLAIN FORMATTED struct -- !query output == Physical Plan == -AdaptiveSparkPlan (5) -+- HashAggregate (4) - +- Exchange (3) - +- HashAggregate (2) - +- Scan parquet default.explain_temp1 (1) +AdaptiveSparkPlan (6) ++- HashAggregate (5) + +- Exchange (4) + +- HashAggregate (3) + +- Project (2) + +- Scan parquet default.explain_temp1 (1) (1) Scan parquet default.explain_temp1 @@ -705,25 +706,29 @@ Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp1] ReadSchema: struct -(2) HashAggregate +(2) Project +Output [3]: [val#x AS _gen_attr_324#x, cast(key#x as bigint) AS _gen_attr_326#xL, if ((val#x > 1)) key#x else null AS _gen_attr_328#x] Input [2]: [key#x, val#x] + +(3) HashAggregate +Input [3]: [_gen_attr_324#x, _gen_attr_326#xL, _gen_attr_328#x] Keys: [] -Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] +Functions [3]: [partial_count(_gen_attr_324#x), partial_sum(_gen_attr_326#xL), partial_count(_gen_attr_328#x) FILTER (WHERE isnotnull(_gen_attr_328#x))] Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] Results [3]: [count#xL, sum#xL, count#xL] -(3) Exchange +(4) Exchange Input [3]: [count#xL, sum#xL, count#xL] Arguments: SinglePartition, true, [id=#x] -(4) HashAggregate +(5) HashAggregate Input [3]: [count#xL, sum#xL, count#xL] Keys: [] -Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] -Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] -Results [2]: 
[(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] +Functions [3]: [count(_gen_attr_324#x), sum(_gen_attr_326#xL), count(_gen_attr_328#x)] +Aggregate Attributes [3]: [count(_gen_attr_324#x)#xL, sum(_gen_attr_326#xL)#xL, count(_gen_attr_328#x)#xL] +Results [2]: [(count(_gen_attr_324#x)#xL + sum(_gen_attr_326#xL)#xL) AS TOTAL#xL, count(_gen_attr_328#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] -(5) AdaptiveSparkPlan +(6) AdaptiveSparkPlan Output [2]: [TOTAL#xL, count(key) FILTER (WHERE (val > 1))#xL] Arguments: isFinalPlan=false diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 2b07dac0e5d0a..c2c6992ec7e0e 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -898,11 +898,12 @@ EXPLAIN FORMATTED struct -- !query output == Physical Plan == -* HashAggregate (5) -+- Exchange (4) - +- HashAggregate (3) - +- * ColumnarToRow (2) - +- Scan parquet default.explain_temp1 (1) +* HashAggregate (6) ++- Exchange (5) + +- HashAggregate (4) + +- * Project (3) + +- * ColumnarToRow (2) + +- Scan parquet default.explain_temp1 (1) (1) Scan parquet default.explain_temp1 @@ -914,23 +915,27 @@ ReadSchema: struct (2) ColumnarToRow [codegen id : 1] Input [2]: [key#x, val#x] -(3) HashAggregate +(3) Project [codegen id : 1] +Output [3]: [val#x AS _gen_attr_348#x, cast(key#x as bigint) AS _gen_attr_350#xL, if ((val#x > 1)) key#x else null AS _gen_attr_352#x] Input [2]: [key#x, val#x] + +(4) HashAggregate +Input [3]: [_gen_attr_348#x, _gen_attr_350#xL, _gen_attr_352#x] Keys: [] -Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] +Functions [3]: [partial_count(_gen_attr_348#x), partial_sum(_gen_attr_350#xL), partial_count(_gen_attr_352#x) FILTER (WHERE isnotnull(_gen_attr_352#x))] Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] Results [3]: [count#xL, sum#xL, count#xL] -(4) Exchange +(5) Exchange Input [3]: [count#xL, sum#xL, count#xL] Arguments: SinglePartition, true, [id=#x] -(5) HashAggregate [codegen id : 2] +(6) HashAggregate [codegen id : 2] Input [3]: [count#xL, sum#xL, count#xL] Keys: [] -Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] -Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] -Results [2]: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] +Functions [3]: [count(_gen_attr_348#x), sum(_gen_attr_350#xL), count(_gen_attr_352#x)] +Aggregate Attributes [3]: [count(_gen_attr_348#x)#xL, sum(_gen_attr_350#xL)#xL, count(_gen_attr_352#x)#xL] +Results [2]: [(count(_gen_attr_348#x)#xL + sum(_gen_attr_350#xL)#xL) AS TOTAL#xL, count(_gen_attr_352#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index d41d25280146b..4dac09fcefe2e 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 37 +-- Number of queries: 63 -- !query @@ -94,6 +94,38 @@ struct +-- !query output +2 + + +-- !query +SELECT 
COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp +-- !query schema +struct +-- !query output +8 2 + + +-- !query +SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp +-- !query schema +struct +-- !query output +2 2 + + +-- !query +SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp +-- !query schema +struct +-- !query output +2450.0 8 2 + + -- !query SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a -- !query schema @@ -177,6 +209,58 @@ struct "2001-01-01 00:00:00") FROM emp GROUP BY dept_id +-- !query schema +struct 2001-01-01 00:00:00)):double> +-- !query output +10 300.0 +100 400.0 +20 300.0 +30 400.0 +70 150.0 +NULL NULL + + +-- !query +SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id +-- !query schema +struct 2001-01-01 00:00:00)):double> +-- !query output +10 300.0 300.0 +100 400.0 400.0 +20 300.0 300.0 +30 400.0 400.0 +70 150.0 150.0 +NULL 400.0 NULL + + +-- !query +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id +-- !query schema +struct DATE '2001-01-01')):double,sum(DISTINCT salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd HH:mm:ss) > 2001-01-01 00:00:00)):double> +-- !query output +10 300.0 300.0 +100 400.0 400.0 +20 300.0 300.0 +30 400.0 400.0 +70 150.0 150.0 +NULL NULL NULL + + +-- !query +SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id +-- !query schema +struct 2001-01-01)):double> +-- !query output +10 3 300.0 300.0 +100 2 400.0 400.0 +20 1 300.0 300.0 +30 1 400.0 400.0 +70 1 150.0 150.0 +NULL 1 400.0 NULL + + -- !query SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1 -- !query schema @@ -261,6 +345,227 @@ struct 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 0 300.0 +30 0 400.0 +70 1 150.0 +NULL 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 2 0 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double> +-- !query output +10 2 0 400.0 +100 2 2 800.0 +20 1 0 300.0 +30 1 0 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp 
group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 2 0 400.0 NULL +100 2 2 800.0 800.0 +20 1 1 300.0 300.0 +30 1 1 400.0 400.0 +70 1 1 150.0 150.0 +NULL 1 1 400.0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 2 0 400.0 NULL +100 2 2 800.0 800.0 +20 1 0 300.0 300.0 +30 1 0 400.0 400.0 +70 1 1 150.0 150.0 +NULL 1 0 400.0 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate):bigint,sum(salary):double> +-- !query output +10 0 2 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 0 1 400.0 +100 2 1 800.0 +20 1 0 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double> +-- !query output +10 0 1 400.0 +100 2 1 NULL +20 1 0 300.0 +30 1 1 NULL +70 1 1 150.0 +NULL 1 0 NULL + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 0 1 400.0 NULL +100 2 1 NULL 800.0 +20 1 0 300.0 300.0 +30 1 1 NULL 400.0 +70 1 1 150.0 150.0 +NULL 1 0 NULL 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT emp_name):bigint,sum(salary):double> +-- !query output +10 0 2 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT emp_name) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 0 1 400.0 +100 2 1 800.0 +20 1 0 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT 
hiredate):bigint,sum(salary):double> +-- !query output +10 NULL 2 400.0 +100 1500 2 800.0 +20 320 1 300.0 +30 430 1 400.0 +70 870 1 150.0 +NULL NULL 1 400.0 + + +-- !query +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 NULL 1 400.0 +100 1500 1 800.0 +20 320 0 300.0 +30 430 1 400.0 +70 870 1 150.0 +NULL NULL 0 400.0 + + +-- !query +select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id +-- !query schema +struct 200)):double,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double> +-- !query output +10 NULL 1 400.0 +100 750.0 1 NULL +20 320.0 0 300.0 +30 430.0 1 NULL +70 870.0 1 150.0 +NULL NULL 0 NULL + + +-- !query +select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id +-- !query schema +struct 0)):bigint,sum(salary):double> +-- !query output +10 2 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + -- !query SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema @@ -309,6 +614,15 @@ struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint> NULL 1 +-- !query +SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) +FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a +-- !query schema +struct 0)):bigint,count(DISTINCT b, c) FILTER (WHERE ((b > 0) AND (c > 2))):bigint> +-- !query output +1 1 + + -- !query SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out index 69f96b02782e3..e1f735e5fe1dc 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 5 -- !query @@ -27,6 +27,20 @@ struct 0)):d 2828.9682539682954 +-- !query +select ten, sum(distinct four) filter (where four > 10) from onek a +group by ten +having exists (select 1 from onek b where sum(distinct a.four) = b.four) +-- !query schema +struct 10)):bigint> +-- !query output +0 NULL +2 NULL +4 NULL +6 NULL +8 NULL + + -- !query select (select count(*) from (values (1)) t0(inner_c)) diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out index 7312c20876296..2619634d7d569 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query @@ -443,6 +443,25 @@ struct NULL 1 +-- !query +select ten, sum(distinct four) filter (where string(four) like '123') from onek a +group by rollup(ten) +-- !query schema +struct +-- !query output +0 NULL +1 NULL +2 NULL +3 NULL +4 NULL +5 NULL +6 NULL +7 NULL +8 NULL +9 NULL +NULL NULL + + -- !query select count(*) from gstest4 group by rollup(unhashable_col,unsortable_col) -- !query schema From fefbce04af1c62f02870a79686a54e7669584a69 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 17 Jul 2020 11:01:04 +0800 Subject: [PATCH 04/11] generate attr use local index. --- .../optimizer/ProjectFilterInAggregates.scala | 8 ++++---- .../test/resources/sql-tests/results/explain.sql.out | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala index 3ec6706c52a4e..0b66ab1edb8b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala @@ -145,13 +145,13 @@ object ProjectFilterInAggregates extends Rule[LogicalPlan] { // RewriteDistinctAggregates may rewrite the logical plan. val unfoldableChildren = af.children.filter(!_.foldable) // Expand projection - val projectionMap = unfoldableChildren.map { - case e if filter.isDefined => + val projectionMap = unfoldableChildren.zipWithIndex.map { + case (e, i) if filter.isDefined => val ife = If(filter.get, e, Literal.create(null, e.dataType)) - e -> Alias(ife, s"_gen_attr_${NamedExpression.newExprId.id}")() + e -> Alias(ife, s"_gen_attr_$i")() // For convenience and unification, we always alias the column, even if // there is no filter. 
- case e => e -> Alias(e, s"_gen_attr_${NamedExpression.newExprId.id}")() + case (e, i) => e -> Alias(e, s"_gen_attr_$i")() } val projection = projectionMap.map(_._2) val exprAttrs = projectionMap.map { kv => diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index c2c6992ec7e0e..5487d4d4ab857 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -916,13 +916,13 @@ ReadSchema: struct Input [2]: [key#x, val#x] (3) Project [codegen id : 1] -Output [3]: [val#x AS _gen_attr_348#x, cast(key#x as bigint) AS _gen_attr_350#xL, if ((val#x > 1)) key#x else null AS _gen_attr_352#x] +Output [3]: [val#x AS _gen_attr_0#x, cast(key#x as bigint) AS _gen_attr_1#xL, if ((val#x > 1)) key#x else null AS _gen_attr_2#x] Input [2]: [key#x, val#x] (4) HashAggregate -Input [3]: [_gen_attr_348#x, _gen_attr_350#xL, _gen_attr_352#x] +Input [3]: [_gen_attr_0#x, _gen_attr_1#xL, _gen_attr_2#x] Keys: [] -Functions [3]: [partial_count(_gen_attr_348#x), partial_sum(_gen_attr_350#xL), partial_count(_gen_attr_352#x) FILTER (WHERE isnotnull(_gen_attr_352#x))] +Functions [3]: [partial_count(_gen_attr_0#x), partial_sum(_gen_attr_1#xL), partial_count(_gen_attr_2#x) FILTER (WHERE isnotnull(_gen_attr_2#x))] Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] Results [3]: [count#xL, sum#xL, count#xL] @@ -933,9 +933,9 @@ Arguments: SinglePartition, true, [id=#x] (6) HashAggregate [codegen id : 2] Input [3]: [count#xL, sum#xL, count#xL] Keys: [] -Functions [3]: [count(_gen_attr_348#x), sum(_gen_attr_350#xL), count(_gen_attr_352#x)] -Aggregate Attributes [3]: [count(_gen_attr_348#x)#xL, sum(_gen_attr_350#xL)#xL, count(_gen_attr_352#x)#xL] -Results [2]: [(count(_gen_attr_348#x)#xL + sum(_gen_attr_350#xL)#xL) AS TOTAL#xL, count(_gen_attr_352#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] +Functions [3]: [count(_gen_attr_0#x), sum(_gen_attr_1#xL), count(_gen_attr_2#x)] +Aggregate Attributes [3]: [count(_gen_attr_0#x)#xL, sum(_gen_attr_1#xL)#xL, count(_gen_attr_2#x)#xL] +Results [2]: [(count(_gen_attr_0#x)#xL + sum(_gen_attr_1#xL)#xL) AS TOTAL#xL, count(_gen_attr_2#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] -- !query From 202a45492db302914bc5b3e05deb3e9827d4e232 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 17 Jul 2020 14:46:08 +0800 Subject: [PATCH 05/11] generate attr use local index. --- .../catalyst/optimizer/ProjectFilterInAggregates.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala index 0b66ab1edb8b7..c94ae481b60dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala @@ -132,6 +132,7 @@ object ProjectFilterInAggregates extends Rule[LogicalPlan] { val aggExpressions = collectAggregateExprs(a.aggregateExpressions) // Constructs pairs between old and new expressions for aggregates. 
val aggExprs = aggExpressions.filter(e => e.children.exists(!_.foldable)) + var currentExprId = 0 val (projections, aggPairs) = aggExprs.map { case ae @ AggregateExpression(af, _, _, filter, _) => // First, In order to reduce costs, it is better to handle the filter clause locally. @@ -145,14 +146,15 @@ object ProjectFilterInAggregates extends Rule[LogicalPlan] { // RewriteDistinctAggregates may rewrite the logical plan. val unfoldableChildren = af.children.filter(!_.foldable) // Expand projection - val projectionMap = unfoldableChildren.zipWithIndex.map { - case (e, i) if filter.isDefined => + val projectionMap = unfoldableChildren.map { + case e if filter.isDefined => val ife = If(filter.get, e, Literal.create(null, e.dataType)) - e -> Alias(ife, s"_gen_attr_$i")() + e -> Alias(ife, s"_gen_attr_$currentExprId")() // For convenience and unification, we always alias the column, even if // there is no filter. - case (e, i) => e -> Alias(e, s"_gen_attr_$i")() + case e => e -> Alias(e, s"_gen_attr_$currentExprId")() } + currentExprId += 1 val projection = projectionMap.map(_._2) val exprAttrs = projectionMap.map { kv => (kv._1, kv._2.toAttribute) From a2c842e2fc502a7dbb2c9601645563ee145b5770 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 17 Jul 2020 18:39:23 +0800 Subject: [PATCH 06/11] Update comment and regenerate golden file. --- .../spark/sql/execution/aggregate/AggUtils.scala | 5 +++-- .../resources/sql-tests/results/explain-aqe.sql.out | 12 ++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 40c022932886d..a2ae6a6945e05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -152,8 +152,9 @@ object AggUtils { // DISTINCT column and the referred attributes in the FILTER clause associated with each // aggregate function. For example, // 1.for AVG(DISTINCT value) GROUP BY key, the grouping expressions will be [key, value]. - // 2.for AVG (DISTINCT value) Filter (WHERE age > 20) GROUP BY key, the grouping expression - // will be [key, value, age]. + // 2.for AVG (DISTINCT value) Filter (WHERE age > 20) GROUP BY key, this will be rewritten + // as AVG (DISTINCT _gen_attr_$id) Filter (WHERE _gen_attr_$id is not null). The grouping + // expression will be [key, _gen_attr_$id].
createAggregate( groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions ++ distinctFilterAttributes, diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 132c08626faec..30b469d17d166 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -707,13 +707,13 @@ Location [not included in comparison]/{warehouse_dir}/explain_temp1] ReadSchema: struct (2) Project -Output [3]: [val#x AS _gen_attr_324#x, cast(key#x as bigint) AS _gen_attr_326#xL, if ((val#x > 1)) key#x else null AS _gen_attr_328#x] +Output [3]: [val#x AS _gen_attr_0#x, cast(key#x as bigint) AS _gen_attr_1#xL, if ((val#x > 1)) key#x else null AS _gen_attr_2#x] Input [2]: [key#x, val#x] (3) HashAggregate -Input [3]: [_gen_attr_324#x, _gen_attr_326#xL, _gen_attr_328#x] +Input [3]: [_gen_attr_0#x, _gen_attr_1#xL, _gen_attr_2#x] Keys: [] -Functions [3]: [partial_count(_gen_attr_324#x), partial_sum(_gen_attr_326#xL), partial_count(_gen_attr_328#x) FILTER (WHERE isnotnull(_gen_attr_328#x))] +Functions [3]: [partial_count(_gen_attr_0#x), partial_sum(_gen_attr_1#xL), partial_count(_gen_attr_2#x) FILTER (WHERE isnotnull(_gen_attr_2#x))] Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] Results [3]: [count#xL, sum#xL, count#xL] @@ -724,9 +724,9 @@ Arguments: SinglePartition, true, [id=#x] (5) HashAggregate Input [3]: [count#xL, sum#xL, count#xL] Keys: [] -Functions [3]: [count(_gen_attr_324#x), sum(_gen_attr_326#xL), count(_gen_attr_328#x)] -Aggregate Attributes [3]: [count(_gen_attr_324#x)#xL, sum(_gen_attr_326#xL)#xL, count(_gen_attr_328#x)#xL] -Results [2]: [(count(_gen_attr_324#x)#xL + sum(_gen_attr_326#xL)#xL) AS TOTAL#xL, count(_gen_attr_328#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] +Functions [3]: [count(_gen_attr_0#x), sum(_gen_attr_1#xL), count(_gen_attr_2#x)] +Aggregate Attributes [3]: [count(_gen_attr_0#x)#xL, sum(_gen_attr_1#xL)#xL, count(_gen_attr_2#x)#xL] +Results [2]: [(count(_gen_attr_0#x)#xL + sum(_gen_attr_1#xL)#xL) AS TOTAL#xL, count(_gen_attr_2#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] (6) AdaptiveSparkPlan Output [2]: [TOTAL#xL, count(key) FILTER (WHERE (val > 1))#xL] From 71277443c4951ed3fb95e912d457da870aec2118 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 17 Jul 2020 19:00:15 +0800 Subject: [PATCH 07/11] Make generated attribute ids unique per child and regenerate golden files.
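This patch also moves the currentExprId increment inside the projection map, so each unfoldable child now receives its own _gen_attr_ name; previously the counter was bumped once per aggregate expression, so several children of one aggregate could collide on the same name. As a side effect the generated ids become 1-based, which is why the explain golden files change again below. A minimal standalone sketch of the per-child naming scheme (illustrative only, not part of the patch; plain strings stand in for Catalyst expressions):

    // Bumping a counter once per element yields collision-free names.
    var currentExprId = 0
    def aliasAll(exprs: Seq[String]): Seq[(String, String)] = exprs.map { e =>
      currentExprId += 1 // increment before use, so the first name is _gen_attr_1
      e -> s"_gen_attr_$currentExprId"
    }
    // aliasAll(Seq("cat1", "cat2")) == Seq(("cat1", "_gen_attr_1"), ("cat2", "_gen_attr_2"))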
--- .../optimizer/ProjectFilterInAggregates.scala | 17 ++++++++++------- .../sql-tests/results/explain-aqe.sql.out | 12 ++++++------ .../resources/sql-tests/results/explain.sql.out | 12 ++++++------ 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala index c94ae481b60dc..320b7fe8e5c3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala @@ -147,14 +147,17 @@ object ProjectFilterInAggregates extends Rule[LogicalPlan] { val unfoldableChildren = af.children.filter(!_.foldable) // Expand projection val projectionMap = unfoldableChildren.map { - case e if filter.isDefined => - val ife = If(filter.get, e, Literal.create(null, e.dataType)) - e -> Alias(ife, s"_gen_attr_$currentExprId")() - // For convenience and unification, we always alias the column, even if - // there is no filter. - case e => e -> Alias(e, s"_gen_attr_$currentExprId")() + case e => + currentExprId += 1 + val ne = if (filter.isDefined) { + If(filter.get, e, Literal.create(null, e.dataType)) + } else { + e + } + // For convenience and unification, we always alias the column, even if + // there is no filter. + e -> Alias(ne, s"_gen_attr_$currentExprId")() } - currentExprId += 1 val projection = projectionMap.map(_._2) val exprAttrs = projectionMap.map { kv => (kv._1, kv._2.toAttribute) diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 30b469d17d166..e13530ea65713 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -707,13 +707,13 @@ Location [not included in comparison]/{warehouse_dir}/explain_temp1] ReadSchema: struct (2) Project -Output [3]: [val#x AS _gen_attr_0#x, cast(key#x as bigint) AS _gen_attr_1#xL, if ((val#x > 1)) key#x else null AS _gen_attr_2#x] +Output [3]: [val#x AS _gen_attr_1#x, cast(key#x as bigint) AS _gen_attr_2#xL, if ((val#x > 1)) key#x else null AS _gen_attr_3#x] Input [2]: [key#x, val#x] (3) HashAggregate -Input [3]: [_gen_attr_0#x, _gen_attr_1#xL, _gen_attr_2#x] +Input [3]: [_gen_attr_1#x, _gen_attr_2#xL, _gen_attr_3#x] Keys: [] -Functions [3]: [partial_count(_gen_attr_0#x), partial_sum(_gen_attr_1#xL), partial_count(_gen_attr_2#x) FILTER (WHERE isnotnull(_gen_attr_2#x))] +Functions [3]: [partial_count(_gen_attr_1#x), partial_sum(_gen_attr_2#xL), partial_count(_gen_attr_3#x) FILTER (WHERE isnotnull(_gen_attr_3#x))] Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] Results [3]: [count#xL, sum#xL, count#xL] @@ -724,9 +724,9 @@ Arguments: SinglePartition, true, [id=#x] (5) HashAggregate Input [3]: [count#xL, sum#xL, count#xL] Keys: [] -Functions [3]: [count(_gen_attr_0#x), sum(_gen_attr_1#xL), count(_gen_attr_2#x)] -Aggregate Attributes [3]: [count(_gen_attr_0#x)#xL, sum(_gen_attr_1#xL)#xL, count(_gen_attr_2#x)#xL] -Results [2]: [(count(_gen_attr_0#x)#xL + sum(_gen_attr_1#xL)#xL) AS TOTAL#xL, count(_gen_attr_2#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] +Functions [3]: [count(_gen_attr_1#x), sum(_gen_attr_2#xL), count(_gen_attr_3#x)] +Aggregate Attributes [3]: [count(_gen_attr_1#x)#xL, sum(_gen_attr_2#xL)#xL, count(_gen_attr_3#x)#xL] +Results [2]: [(count(_gen_attr_1#x)#xL + sum(_gen_attr_2#xL)#xL) AS TOTAL#xL, count(_gen_attr_3#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] (6) AdaptiveSparkPlan Output [2]: [TOTAL#xL, count(key) FILTER (WHERE (val > 1))#xL] diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 5487d4d4ab857..cccac580088a1 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -916,13 +916,13 @@ ReadSchema: struct Input [2]: [key#x, val#x] (3) Project [codegen id : 1] -Output [3]: [val#x AS _gen_attr_0#x, cast(key#x as bigint) AS _gen_attr_1#xL, if ((val#x > 1)) key#x else null AS _gen_attr_2#x] +Output [3]: [val#x AS _gen_attr_1#x, cast(key#x as bigint) AS _gen_attr_2#xL, if ((val#x > 1)) key#x else null AS _gen_attr_3#x] Input [2]: [key#x, val#x] (4) HashAggregate -Input [3]: [_gen_attr_0#x, _gen_attr_1#xL, _gen_attr_2#x] +Input [3]: [_gen_attr_1#x, _gen_attr_2#xL, _gen_attr_3#x] Keys: [] -Functions [3]: [partial_count(_gen_attr_0#x), partial_sum(_gen_attr_1#xL), partial_count(_gen_attr_2#x) FILTER (WHERE isnotnull(_gen_attr_2#x))] +Functions [3]: [partial_count(_gen_attr_1#x), partial_sum(_gen_attr_2#xL), partial_count(_gen_attr_3#x) FILTER (WHERE isnotnull(_gen_attr_3#x))] Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] Results [3]: [count#xL, sum#xL, count#xL] @@ -933,9 +933,9 @@ Arguments: SinglePartition, true, [id=#x] (6) HashAggregate [codegen id : 2] Input [3]: [count#xL, sum#xL, count#xL] Keys: [] -Functions [3]: [count(_gen_attr_0#x), sum(_gen_attr_1#xL), count(_gen_attr_2#x)] -Aggregate Attributes [3]: [count(_gen_attr_0#x)#xL, sum(_gen_attr_1#xL)#xL, count(_gen_attr_2#x)#xL] -Results [2]: [(count(_gen_attr_0#x)#xL + sum(_gen_attr_1#xL)#xL) AS TOTAL#xL, count(_gen_attr_2#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] +Functions [3]: [count(_gen_attr_1#x), sum(_gen_attr_2#xL), count(_gen_attr_3#x)] +Aggregate Attributes [3]: [count(_gen_attr_1#x)#xL, sum(_gen_attr_2#xL)#xL, count(_gen_attr_3#x)#xL] +Results [2]: [(count(_gen_attr_1#x)#xL + sum(_gen_attr_2#xL)#xL) AS TOTAL#xL, count(_gen_attr_3#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] -- !query From 98e97e8b075db32ee7eca48986611f5dc4543295 Mon Sep 17 00:00:00 2001 From: beliefer Date: Sat, 18 Jul 2020 09:58:14 +0800 Subject: [PATCH 08/11] Replace old attributes with new attributes. --- .../spark/sql/execution/aggregate/AggUtils.scala | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index a2ae6a6945e05..1cd4bef01851e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -142,7 +142,6 @@ object AggUtils { val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) - val distinctFilterAttributes = functionsWithDistinct.flatMap(_.filterAttributes.toSeq) // 1. Create an Aggregate Operator for partial aggregations. val partialAggregate: SparkPlan = { @@ -156,12 +155,10 @@ object AggUtils { // as AVG (DISTINCT _gen_attr_$id) Filter (WHERE _gen_attr_$id is not null). The grouping // expression will be [key, _gen_attr_$id].
createAggregate( - groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions ++ - distinctFilterAttributes, + groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ - distinctFilterAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = child) } @@ -173,13 +170,11 @@ createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), - groupingExpressions = groupingAttributes ++ distinctAttributes ++ - distinctFilterAttributes, + groupingExpressions = groupingAttributes ++ distinctAttributes, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, resultExpressions = groupingAttributes ++ distinctAttributes ++ - distinctFilterAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = partialAggregate) } @@ -210,7 +205,9 @@ // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. - val filter = functionsWithDistinct(i).filter + val filter = functionsWithDistinct(i).filter.map(_.transform { + case a: Attribute => distinctColumnAttributeLookup.getOrElse(a, a) + }) val expr = AggregateExpression(func, Partial, isDistinct = true, filter) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute From 2253499de077b26aa77c647af3c514abfba4483e Mon Sep 17 00:00:00 2001 From: beliefer Date: Sat, 18 Jul 2020 10:00:36 +0800 Subject: [PATCH 09/11] Revert comments. --- .../apache/spark/sql/execution/aggregate/AggUtils.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 1cd4bef01851e..767417fa5b4fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -148,12 +148,8 @@ object AggUtils { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) // We will group by the original grouping expression, plus an additional expression for the - // DISTINCT column and the referred attributes in the FILTER clause associated with each - // aggregate function. For example, - // 1.for AVG(DISTINCT value) GROUP BY key, the grouping expressions will be [key, value]. - // 2.for AVG (DISTINCT value) Filter (WHERE age > 20) GROUP BY key, this will be rewritten - // as AVG (DISTINCT _gen_attr_$id) Filter (WHERE _gen_attr_$id is not null). The grouping - // expression will be [key, _gen_attr_$id]. + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value].
createAggregate( groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions, aggregateExpressions = aggregateExpressions, From 715958229de1f9193b18f8955ac6064403046b15 Mon Sep 17 00:00:00 2001 From: beliefer Date: Sat, 18 Jul 2020 11:48:18 +0800 Subject: [PATCH 10/11] Update comments. --- .../optimizer/ProjectFilterInAggregates.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala index 320b7fe8e5c3f..540a24ca45bf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * COUNT(DISTINCT cat1) AS cat1_cnt, * COUNT(DISTINCT cat2) FILTER (WHERE id > 1) AS cat2_cnt, * SUM(value) AS total, - * SUM(value) FILTER (WHERE key = "a") AS total2 + * SUM(value) FILTER (WHERE key = "a") AS total2 * FROM * data * GROUP BY @@ -51,9 +51,9 @@ import org.apache.spark.sql.catalyst.rules.Rule * Aggregate( * key = ['key] * functions = [COUNT(DISTINCT 'cat1), - * COUNT(DISTINCT 'cat2) with FILTER('id > 1), + * COUNT(DISTINCT 'cat2) with FILTER('id > 1), * SUM('value), - * SUM('value) with FILTER('key = "a")] + * SUM('value) with FILTER('key = "a")] * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total, 'total2]) * LocalTableScan [...] * }}} @@ -62,15 +62,15 @@ import org.apache.spark.sql.catalyst.rules.Rule * {{{ * Aggregate( * key = ['key] - * functions = [count('_gen_attr_1), - * count('_gen_attr_2) with FILTER('_gen_attr_2 is not null), - * sum('_gen_attr_3), - * sum('_gen_attr_4) with FILTER('_gen_attr_4 is not null)] + * functions = [COUNT(DISTINCT '_gen_attr_1), + * COUNT(DISTINCT '_gen_attr_2) with FILTER('_gen_attr_2 is not null), + * SUM('_gen_attr_3), + * SUM('_gen_attr_4) with FILTER('_gen_attr_4 is not null)] * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total, 'total2]) * Project( * projectionList = ['key, * 'cat1, - * if ('id > 1) 'cat2 else null, + * if ('id > 1) 'cat2 else null, * cast('value as bigint), * if ('key = "a") cast('value as bigint) else null] * output = ['key, '_gen_attr_1, '_gen_attr_2, '_gen_attr_3, '_gen_attr_4]) From ba7c3a4d3ad5827f03c09356f93079732284e29d Mon Sep 17 00:00:00 2001 From: beliefer Date: Sat, 18 Jul 2020 11:49:36 +0800 Subject: [PATCH 11/11] Update comments. --- .../sql/catalyst/optimizer/ProjectFilterInAggregates.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala index 540a24ca45bf2..b599001bb8b56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ProjectFilterInAggregates.scala @@ -70,7 +70,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * Project( * projectionList = ['key, * 'cat1, - * if ('id > 1) 'cat2 else null, + * if ('id > 1) 'cat2 else null, * cast('value as bigint), * if ('key = "a") cast('value as bigint) else null] * output = ['key, '_gen_attr_1, '_gen_attr_2, '_gen_attr_3, '_gen_attr_4])
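With the whole series applied, an aggregate function can carry DISTINCT and a FILTER clause at the same time, as exercised by the new groupingsets golden test above; the rule rewrites the filter into the Project + Aggregate shape documented in this scaladoc. A minimal end-to-end smoke test (a sketch assuming a Spark build with these patches applied; the view name and sample data are hypothetical):

    import org.apache.spark.sql.SparkSession

    object DistinctFilterSmokeTest {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("distinct-filter-smoke-test")
          .getOrCreate()
        import spark.implicits._

        // Hypothetical sample data registered as a temporary view.
        Seq(("a", 1), ("a", 1), ("a", 2), ("b", 3))
          .toDF("key", "value")
          .createOrReplaceTempView("data")

        // An aggregate that is both DISTINCT and filtered; before this series the
        // analyzer rejected the combination, now it is planned via the rewrite above.
        spark.sql(
          """SELECT key,
            |       SUM(DISTINCT value) FILTER (WHERE key = 'a') AS total
            |FROM data
            |GROUP BY key""".stripMargin).show()
        // Expected: key 'a' -> 3 (distinct values 1 and 2); key 'b' -> NULL,
        // since no row in that group satisfies the filter.

        spark.stop()
      }
    }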