From 2dc9db472a5369c958aad3a89309707a4c62dd74 Mon Sep 17 00:00:00 2001 From: beliefer Date: Sat, 1 Feb 2020 19:26:14 +0800 Subject: [PATCH 01/21] Support distinct with filter --- .../sql/catalyst/analysis/Analyzer.scala | 12 +- .../optimizer/RewriteDistinctAggregates.scala | 177 ++++++++++++++++-- .../analysis/AnalysisErrorSuite.scala | 5 - .../sql-tests/inputs/group-by-filter.sql | 39 ++-- .../inputs/postgreSQL/aggregates_part3.sql | 7 +- .../inputs/postgreSQL/groupingsets.sql | 5 +- 6 files changed, 196 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3fd5039a4f116..01bd3edd290c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1814,15 +1814,9 @@ class Analyzer( } // We get an aggregate function, we need to wrap it in an AggregateExpression. case agg: AggregateFunction => - // TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT - if (filter.isDefined) { - if (isDistinct) { - failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " + - "at the same time") - } else if (!filter.get.deterministic) { - failAnalysis("FILTER expression is non-deterministic, " + - "it cannot be used in aggregate functions") - } + if (filter.isDefined && !filter.get.deterministic) { + failAnalysis("FILTER expression is non-deterministic, " + + "it cannot be used in aggregate functions") } AggregateExpression(agg, Complete, isDistinct, filter) // This function is not an aggregate function, just return the resolved one. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index e5571069a7c41..dc05bb363e356 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -31,10 +31,10 @@ import org.apache.spark.sql.types.IntegerType * First example: query without filter clauses (in scala): * {{{ * val data = Seq( - * ("a", "ca1", "cb1", 10), - * ("a", "ca1", "cb2", 5), - * ("b", "ca1", "cb1", 13)) - * .toDF("key", "cat1", "cat2", "value") + * (1, "a", "ca1", "cb1", 10), + * (2, "a", "ca1", "cb2", 5), + * (3, "b", "ca1", "cb1", 13)) + * .toDF("id", "key", "cat1", "cat2", "value") * data.createOrReplaceTempView("data") * * val agg = data.groupBy($"key") @@ -118,7 +118,66 @@ import org.apache.spark.sql.types.IntegerType * LocalTableScan [...] * }}} * - * The rule does the following things here: + * Third example: single distinct aggregate function with filter clauses (in sql): + * {{{ + * SELECT + * COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt1, + * COUNT(DISTINCT cat1) as cat1_cnt2, + * SUM(value) AS total + * FROM + * data + * GROUP BY + * key + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1) with FILTER('id > 1), + * COUNT(DISTINCT 'cat1), + * sum('value)] + * output = ['key, 'cat1_cnt1, 'cat1_cnt2, 'total]) + * LocalTableScan [...] 
+ * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1)) '_gen_distinct_1 else null), + * count(if (('gid = 2)) '_gen_distinct_2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt1, 'cat1_cnt2, 'total]) + * Aggregate( + * key = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid] + * functions = [sum('value)] + * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'total]) + * Expand( + * projections = [('key, null, null, 0, 'value), + * ('key, '_gen_distinct_1, null, 1, null), + * ('key, null, '_gen_distinct_2, 2, null)] + * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value]) + * Expand( + * projections = [('key, if ('id > 1) 'cat1 else null, 'cat1, cast('value as bigint))] + * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'value]) + * LocalTableScan [...] + * }}} + * + * The rule serves two purposes: + * 1. Expand distinct aggregates which exists filter clause. + * 2. Rewrite when aggregate exists at least two distinct aggregates. + * + * The first child rule does the following things here: + * 1. Guaranteed to compute filter clause locally. + * 2. The attributes referenced by different distinct aggregate expressions are likely to overlap, + * and if no additional processing is performed, data loss will occur. To prevent this, we + * generate new attributes and replace the original ones. + * 3. If we apply the first rule to distinct aggregate expressions which exists filter + * clause, the aggregate after expand may have at least two distinct aggregates, so we need to + * apply the second rule too. + * + * The second child rule does the following things here: * 1. Expand the data. There are three aggregation groups in this query: * i. the non-distinct group; * ii. the distinct 'cat1 group; @@ -148,24 +207,106 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val distinctAggs = exprs.flatMap { _.collect { case ae: AggregateExpression if ae.isDistinct => ae }} - // We need at least two distinct aggregates for this rule because aggregation - // strategy can handle a single distinct group. + // This rule serves two purposes: + // One is to rewrite when there exists at least two distinct aggregates. We need at least + // two distinct aggregates for this rule because aggregation strategy can handle a single + // distinct group. + // Another is to expand distinct aggregates which exists filter clause so that we can + // evaluate the filter locally. // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). - distinctAggs.size > 1 + distinctAggs.size >= 1 || distinctAggs.exists(_.filter.isDefined) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a) + case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => + val expandAggregate = extractFiltersInDistinctAggregate(a) + rewriteDistinctAggregate(expandAggregate) } - def rewrite(a: Aggregate): Aggregate = { + private def extractFiltersInDistinctAggregate(a: Aggregate): Aggregate = { + val aggExpressions = collectAggregateExprs(a) + val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct) + if (distinctAggExpressions.exists(_.filter.isDefined)) { + // Setup expand for the 'regular' aggregate expressions. 
Because we will construct a new + // aggregate, the children of the distinct aggregates will be changed to the generate + // ones, so we need creates new references to avoid collisions between distinct and + // regular aggregate children. + val regularAggExprs = regularAggExpressions.filter(_.children.exists(!_.foldable)) + val regularFunChildren = regularAggExprs + .flatMap(_.aggregateFunction.children.filter(!_.foldable)) + val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) + val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap + val regularAggMap = regularAggExprs.map { + case ae @ AggregateExpression(af, _, _, filter, _) => + val newChildren = af.children.map(c => regularAggChildAttrLookup.getOrElse(c, c)) + val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val filterOpt = filter.map(_.transform { + case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) + }) + val aggExpr = ae.copy(aggregateFunction = raf, filter = filterOpt) + (ae, aggExpr) + } - // Collect all aggregate expressions. - val aggExpressions = a.aggregateExpressions.flatMap { e => - e.collect { - case ae: AggregateExpression => ae + // Setup expand for the 'distinct' aggregate expressions. + val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable)) + val (projections, expressionAttrs, aggExprPairs) = distinctAggExprs.map { + case ae @ AggregateExpression(af, _, _, filter, _) => + // Why do we need to construct the `exprId` ? + // First, In order to reduce costs, it is better to handle the filter clause locally. + // e.g. COUNT (DISTINCT a) FILTER (WHERE id > 1), evaluate expression + // If(id > 1) 'a else null first, and use the result as output. + // Second, If at least two DISTINCT aggregate expression which may references the + // same attributes. We need to construct the generated attributes so as the output not + // lost. e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id > 1) will output + // attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead of two 'a. + // Note: We just need to illusion the expression with filter clause. + // The illusionary mechanism may result in multiple distinct aggregations uses + // different column, so we still need to call `rewrite`. + val exprId = NamedExpression.newExprId.id + val unfoldableChildren = af.children.filter(!_.foldable) + val exprAttrs = unfoldableChildren.map { e => + (e, AttributeReference(s"_gen_distinct_$exprId", e.dataType, nullable = true)()) + } + val exprAttrLookup = exprAttrs.toMap + val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c)) + val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val aggExpr = ae.copy(aggregateFunction = raf, filter = None) + // Expand projection + val projection = unfoldableChildren.map { + case e if filter.isDefined => If(filter.get, e, nullify(e)) + case e => e + } + (projection, exprAttrs, (ae, aggExpr)) + }.unzip3 + val distinctAggChildAttrs = expressionAttrs.flatten.map(_._2) + val allAggAttrs = regularAggChildAttrMap.map(_._2) ++ distinctAggChildAttrs + // Construct the aggregate input projection. 
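+      // It projects the grouping expressions, the regular aggregate children, and, for each
+      // distinct aggregate child, either the child itself or `If(filter, child, null)` when
+      // that aggregate has a filter, so every filter is evaluated once, before aggregation.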
+ val rewriteAggProjections = + Seq(a.groupingExpressions ++ regularAggChildren ++ projections.flatten) + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() } + val groupByAttrs = groupByMap.map(_._2) + // Construct the expand operator. + val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs, a.child) + val rewriteAggExprLookup = (aggExprPairs ++ regularAggMap).toMap + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) + }.asInstanceOf[NamedExpression] + } + val expandAggregate = Aggregate(groupByAttrs, patchedAggExpressions, expand) + expandAggregate + } else { + a } + } + + private def rewriteDistinctAggregate(a: Aggregate): Aggregate = { + val aggExpressions = collectAggregateExprs(a) // Extract distinct aggregate expressions. val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => @@ -331,6 +472,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } } + private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = { + a.aggregateExpressions.flatMap { e => + e.collect { + case ae: AggregateExpression => ae + } + } + } + private def nullify(e: Expression) = Literal.create(null, e.dataType) private def expressionAttributePair(e: Expression) = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 7023dbe2a3672..6966fc32787d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -190,11 +190,6 @@ class AnalysisErrorSuite extends AnalysisTest { "FILTER (WHERE c > 1)"), "FILTER predicate specified, but aggregate is not an aggregate function" :: Nil) - errorTest( - "DISTINCT and FILTER cannot be used in aggregate functions at the same time", - CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"), - "DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil) - errorTest( "FILTER expression is non-deterministic, it cannot be used in aggregate functions", CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"), diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql index beb5b9e5fe516..14f0eb70657da 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql @@ -33,8 +33,10 @@ SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp; SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp; SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp; SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT 
id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp; +SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp; -- Aggregate with filter and non-empty GroupBy expressions. SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a; @@ -44,8 +46,10 @@ SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id; SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id; SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id; -- Aggregate with filter and grouped by literals. 
SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1; @@ -58,13 +62,21 @@ select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), 
sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, sum(distinct (id + dept_id))) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; -- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE. SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1; @@ -78,9 +90,8 @@ SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1; SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1; -- Aggregate with filter, foldable input and multiple distinct groups. --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) --- FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; +SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) +FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; -- Check analysis exceptions SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql index 746b677234832..657ea59ec8f11 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql @@ -241,10 +241,9 @@ select sum(1/ten) filter (where ten > 0) from tenk1; -- select ten, sum(distinct four) filter (where four::text ~ '123') from onek a -- group by ten; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select ten, sum(distinct four) filter (where four > 10) from onek a --- group by ten --- having exists (select 1 from onek b where sum(distinct a.four) = b.four); +select ten, sum(distinct four) filter (where four > 10) from onek a +group by ten +having exists (select 1 from onek b where sum(distinct a.four) = b.four); -- [SPARK-28682] ANSI SQL: Collation Support -- select max(foo COLLATE "C") filter (where (bar collate "POSIX") > '0') diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql index fc54d179f742c..45617c53166aa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql @@ -336,9 +336,8 @@ order by 2,1; -- order by 2,1; -- FILTER queries --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select ten, sum(distinct four) filter (where string(four) like '123') from onek a --- group by rollup(ten); +select ten, sum(distinct four) filter (where string(four) like '123') from onek a +group by rollup(ten); -- More rescan tests -- [SPARK-27877] ANSI SQL: LATERAL derived 
table(T491) From 6a32d83358ce871715c2a3587b0cbac70f384f4e Mon Sep 17 00:00:00 2001 From: beliefer Date: Sat, 1 Feb 2020 22:07:11 +0800 Subject: [PATCH 02/21] Add results of test case --- .../sql-tests/results/group-by-filter.sql.out | 291 +++++++++++++++++- .../postgreSQL/aggregates_part3.sql.out | 16 +- .../results/postgreSQL/groupingsets.sql.out | 21 +- 3 files changed, 325 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index a032678e90fe8..079b651445f62 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 37 +-- Number of queries: 61 -- !query @@ -94,6 +94,38 @@ struct +-- !query output +2 + + +-- !query +SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp +-- !query schema +struct +-- !query output +8 2 + + +-- !query +SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp +-- !query schema +struct +-- !query output +2 2 + + +-- !query +SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp +-- !query schema +struct +-- !query output +2450.0 8 2 + + -- !query SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a -- !query schema @@ -177,6 +209,58 @@ struct "2001-01-01 00:00:00") FROM emp GROUP BY dept_id +-- !query schema +struct 2001-01-01 00:00:00)):double> +-- !query output +10 300.0 +100 400.0 +20 300.0 +30 400.0 +70 150.0 +NULL NULL + + +-- !query +SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id +-- !query schema +struct 2001-01-01 00:00:00)):double> +-- !query output +10 300.0 300.0 +100 400.0 400.0 +20 300.0 300.0 +30 400.0 400.0 +70 150.0 150.0 +NULL 400.0 NULL + + +-- !query +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id +-- !query schema +struct DATE '2001-01-01')):double,sum(DISTINCT salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd HH:mm:ss) > 2001-01-01 00:00:00)):double> +-- !query output +10 300.0 300.0 +100 400.0 400.0 +20 300.0 300.0 +30 400.0 400.0 +70 150.0 150.0 +NULL NULL NULL + + +-- !query +SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id +-- !query schema +struct 2001-01-01)):double> +-- !query output +10 3 300.0 300.0 +100 2 400.0 400.0 +20 1 300.0 300.0 +30 1 400.0 400.0 +70 1 150.0 150.0 +NULL 1 400.0 NULL + + -- !query SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1 -- !query schema @@ -261,6 +345,202 @@ struct 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by 
dept_id +-- !query schema +struct 500)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 0 300.0 +30 0 400.0 +70 1 150.0 +NULL 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 2 0 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double> +-- !query output +10 2 0 400.0 +100 2 2 800.0 +20 1 0 300.0 +30 1 0 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 2 0 400.0 NULL +100 2 2 800.0 800.0 +20 1 1 300.0 300.0 +30 1 1 400.0 400.0 +70 1 1 150.0 150.0 +NULL 1 1 400.0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 2 0 400.0 NULL +100 2 2 800.0 800.0 +20 1 0 300.0 300.0 +30 1 0 400.0 400.0 +70 1 1 150.0 150.0 +NULL 1 0 400.0 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate):bigint,sum(salary):double> +-- !query output +10 0 2 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 0 1 400.0 +100 2 1 800.0 +20 1 0 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double> +-- !query output +10 0 1 400.0 +100 2 1 NULL +20 1 0 300.0 +30 1 1 NULL +70 1 1 150.0 +NULL 1 0 NULL + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 0 1 400.0 NULL +100 2 1 NULL 800.0 +20 1 0 300.0 300.0 +30 1 1 NULL 400.0 +70 1 1 150.0 150.0 +NULL 1 0 NULL 400.0 + + +-- !query +select 
dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT emp_name):bigint,sum(salary):double> +-- !query output +10 0 2 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT emp_name) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 0 1 400.0 +100 2 1 800.0 +20 1 0 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, sum(distinct (id + dept_id))) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input ')' expecting {, '(', ',', '.', '[', 'ADD', 'AFTER', 'ALL', 'ALTER', 'ANALYZE', 'AND', 'ANTI', 'ANY', 'ARCHIVE', 'ARRAY', 'AS', 'ASC', 'AT', 'AUTHORIZATION', 'BETWEEN', 'BOTH', 'BUCKET', 'BUCKETS', 'BY', 'CACHE', 'CASCADE', 'CASE', 'CAST', 'CHANGE', 'CHECK', 'CLEAR', 'CLUSTER', 'CLUSTERED', 'CODEGEN', 'COLLATE', 'COLLECTION', 'COLUMN', 'COLUMNS', 'COMMENT', 'COMMIT', 'COMPACT', 'COMPACTIONS', 'COMPUTE', 'CONCATENATE', 'CONSTRAINT', 'COST', 'CREATE', 'CROSS', 'CUBE', 'CURRENT', 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'CURRENT_USER', 'DATA', 'DATABASE', DATABASES, 'DAY', 'DBPROPERTIES', 'DEFINED', 'DELETE', 'DELIMITED', 'DESC', 'DESCRIBE', 'DFS', 'DIRECTORIES', 'DIRECTORY', 'DISTINCT', 'DISTRIBUTE', 'DROP', 'ELSE', 'END', 'ESCAPE', 'ESCAPED', 'EXCEPT', 'EXCHANGE', 'EXISTS', 'EXPLAIN', 'EXPORT', 'EXTENDED', 'EXTERNAL', 'EXTRACT', 'FALSE', 'FETCH', 'FIELDS', 'FILTER', 'FILEFORMAT', 'FIRST', 'FIRST_VALUE', 'FOLLOWING', 'FOR', 'FOREIGN', 'FORMAT', 'FORMATTED', 'FROM', 'FULL', 'FUNCTION', 'FUNCTIONS', 'GLOBAL', 'GRANT', 'GROUP', 'GROUPING', 'HAVING', 'HOUR', 'IF', 'IGNORE', 'IMPORT', 'IN', 'INDEX', 'INDEXES', 'INNER', 'INPATH', 'INPUTFORMAT', 'INSERT', 'INTERSECT', 'INTERVAL', 'INTO', 'IS', 'ITEMS', 'JOIN', 'KEYS', 'LAST', 'LAST_VALUE', 'LATERAL', 'LAZY', 'LEADING', 'LEFT', 'LIKE', 'LIMIT', 'LINES', 'LIST', 'LOAD', 'LOCAL', 'LOCATION', 'LOCK', 'LOCKS', 'LOGICAL', 'MACRO', 'MAP', 'MATCHED', 'MERGE', 'MINUTE', 'MONTH', 'MSCK', 'NAMESPACE', 'NAMESPACES', 'NATURAL', 'NO', NOT, 'NULL', 'NULLS', 'OF', 'ON', 'ONLY', 'OPTION', 'OPTIONS', 'OR', 'ORDER', 'OUT', 'OUTER', 'OUTPUTFORMAT', 'OVER', 'OVERLAPS', 'OVERLAY', 'OVERWRITE', 'PARTITION', 'PARTITIONED', 'PARTITIONS', 'PERCENT', 'PIVOT', 'PLACING', 'POSITION', 'PRECEDING', 'PRIMARY', 'PRINCIPALS', 'PROPERTIES', 'PURGE', 'QUERY', 'RANGE', 'RECORDREADER', 'RECORDWRITER', 'RECOVER', 'REDUCE', 'REFERENCES', 'REFRESH', 'RENAME', 'REPAIR', 'REPLACE', 'RESET', 'RESPECT', 'RESTRICT', 'REVOKE', 'RIGHT', RLIKE, 'ROLE', 'ROLES', 'ROLLBACK', 'ROLLUP', 'ROW', 'ROWS', 'SCHEMA', 'SECOND', 'SELECT', 'SEMI', 'SEPARATED', 'SERDE', 'SERDEPROPERTIES', 'SESSION_USER', 'SET', 'MINUS', 'SETS', 'SHOW', 'SKEWED', 'SOME', 'SORT', 'SORTED', 'START', 'STATISTICS', 'STORED', 'STRATIFY', 'STRUCT', 'SUBSTR', 'SUBSTRING', 'TABLE', 'TABLES', 'TABLESAMPLE', 'TBLPROPERTIES', TEMPORARY, 'TERMINATED', 'THEN', 'TO', 'TOUCH', 'TRAILING', 'TRANSACTION', 'TRANSACTIONS', 'TRANSFORM', 'TRIM', 'TRUE', 'TRUNCATE', 'TYPE', 'UNARCHIVE', 'UNBOUNDED', 'UNCACHE', 'UNION', 'UNIQUE', 
'UNKNOWN', 'UNLOCK', 'UNSET', 'UPDATE', 'USE', 'USER', 'USING', 'VALUES', 'VIEW', 'WHEN', 'WHERE', 'WINDOW', 'WITH', 'YEAR', EQ, '<=>', '<>', '!=', '<', LTE, '>', GTE, '+', '-', '*', '/', '%', 'DIV', '&', '|', '||', '^', IDENTIFIER, BACKQUOTED_IDENTIFIER}(line 1, pos 44) + +== SQL == +select dept_id, sum(distinct (id + dept_id))) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +--------------------------------------------^^^ + + +-- !query +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 NULL 1 400.0 +100 1500 1 800.0 +20 320 0 300.0 +30 430 1 400.0 +70 870 1 150.0 +NULL NULL 0 400.0 + + +-- !query +select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id +-- !query schema +struct 200)):double,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double> +-- !query output +10 NULL 1 400.0 +100 750.0 1 NULL +20 320.0 0 300.0 +30 430.0 1 NULL +70 870.0 1 150.0 +NULL NULL 0 NULL + + -- !query SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema @@ -309,6 +589,15 @@ struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint> NULL 1 +-- !query +SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) +FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a +-- !query schema +struct 0)):bigint,count(DISTINCT b, c) FILTER (WHERE ((b > 0) AND (c > 2))):bigint> +-- !query output +1 1 + + -- !query SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out index 69f96b02782e3..e1f735e5fe1dc 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 5 -- !query @@ -27,6 +27,20 @@ struct 0)):d 2828.9682539682954 +-- !query +select ten, sum(distinct four) filter (where four > 10) from onek a +group by ten +having exists (select 1 from onek b where sum(distinct a.four) = b.four) +-- !query schema +struct 10)):bigint> +-- !query output +0 NULL +2 NULL +4 NULL +6 NULL +8 NULL + + -- !query select (select count(*) from (values (1)) t0(inner_c)) diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out index 24fd9dcbfc826..1ee653ad67bb7 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query @@ -443,6 +443,25 @@ struct NULL 1 +-- !query +select ten, sum(distinct four) filter (where 
string(four) like '123') from onek a +group by rollup(ten) +-- !query schema +struct +-- !query output +0 NULL +1 NULL +2 NULL +3 NULL +4 NULL +5 NULL +6 NULL +7 NULL +8 NULL +9 NULL +NULL NULL + + -- !query select count(*) from gstest4 group by rollup(unhashable_col,unsortable_col) -- !query schema From 5c38bbee893fb684980c2b10ba0442311b5cedae Mon Sep 17 00:00:00 2001 From: beliefer Date: Sun, 2 Feb 2020 10:48:38 +0800 Subject: [PATCH 03/21] Optimize code --- .../spark/sql/catalyst/dsl/package.scala | 2 + .../expressions/aggregate/interfaces.scala | 16 +++-- .../optimizer/RewriteDistinctAggregates.scala | 70 ++++++++----------- .../RewriteDistinctAggregatesSuite.scala | 26 ++++++- 4 files changed, 69 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 40998080bc4e3..aa2d6f59b677a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -170,6 +170,8 @@ package object dsl { def count(e: Expression): Expression = Count(e).toAggregateExpression() def countDistinct(e: Expression*): Expression = Count(e).toAggregateExpression(isDistinct = true) + def countDistinct(filter: Option[Expression], e: Expression*): Expression = + Count(e).toAggregateExpression(isDistinct = true, filter = filter) def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = HyperLogLogPlusPlus(e, rsd).toAggregateExpression() def avg(e: Expression): Expression = Average(e).toAggregateExpression() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 222ad6fab19e0..40508e18b5a00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -216,15 +216,21 @@ abstract class AggregateFunction extends Expression { def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) /** - * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct` - * flag of the [[AggregateExpression]] to the given value because + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] with `isDistinct` + * flag and `filter` option of the [[AggregateExpression]] to the given value because * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, - * and the flag indicating if this aggregation is distinct aggregation or not. + * the flag indicating if this aggregation is distinct aggregation or not and filter option. * An [[AggregateFunction]] should not be used without being wrapped in * an [[AggregateExpression]]. 
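+   *
+   * For example (a usage sketch; `e` and `predicate` stand for arbitrary resolved
+   * expressions), `Count(e).toAggregateExpression(isDistinct = true, filter = Some(predicate))`
+   * builds the equivalent of SQL's `COUNT(DISTINCT e) FILTER (WHERE predicate)`.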
*/
-  def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
-    AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
+  def toAggregateExpression(
+      isDistinct: Boolean,
+      filter: Option[Expression] = None): AggregateExpression = {
+    AggregateExpression(
+      aggregateFunction = this,
+      mode = Complete,
+      isDistinct = isDistinct,
+      filter = filter)
   }
 
   def sql(isDistinct: Boolean): String = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index dc05bb363e356..62f74459937a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -121,8 +121,8 @@ import org.apache.spark.sql.types.IntegerType
  * Third example: single distinct aggregate function with filter clauses (in sql):
  * {{{
  *   SELECT
- *     COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt1,
- *     COUNT(DISTINCT cat1) as cat1_cnt2,
+ *     COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
+ *     COUNT(DISTINCT cat2) as cat2_cnt,
  *     SUM(value) AS total
  *   FROM
  *     data
@@ -135,9 +135,9 @@ import org.apache.spark.sql.types.IntegerType
  *   Aggregate(
  *      key = ['key]
  *      functions = [COUNT(DISTINCT 'cat1) with FILTER('id > 1),
- *                   COUNT(DISTINCT 'cat1),
+ *                   COUNT(DISTINCT 'cat2),
  *                   sum('value)]
- *      output = ['key, 'cat1_cnt1, 'cat1_cnt2, 'total])
+ *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
  *   LocalTableScan [...]
  * }}}
  *
@@ -148,36 +148,33 @@ import org.apache.spark.sql.types.IntegerType
  *      functions = [count(if (('gid = 1)) '_gen_distinct_1 else null),
  *                   count(if (('gid = 2)) '_gen_distinct_2 else null),
  *                   first(if (('gid = 0)) 'total else null) ignore nulls]
- *      output = ['key, 'cat1_cnt1, 'cat1_cnt2, 'total])
+ *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
  *   Aggregate(
  *      key = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid]
  *      functions = [sum('value)]
  *      output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'total])
  *   Expand(
  *      projections = [('key, null, null, 0, 'value),
- *                     ('key, '_gen_distinct_1, null, 1, null),
- *                     ('key, null, '_gen_distinct_2, 2, null)]
+ *                    ('key, '_gen_distinct_1, null, 1, null),
+ *                    ('key, null, '_gen_distinct_2, 2, null)]
  *      output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value])
  *   Expand(
- *      projections = [('key, if ('id > 1) 'cat1 else null, 'cat1, cast('value as bigint))]
+ *      projections = [('key, if ('id > 1) 'cat1 else null, 'cat2, cast('value as bigint))]
  *      output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'value])
  *   LocalTableScan [...]
  * }}}
  *
- * The rule serves two purposes:
- * 1. Expand distinct aggregates which exists filter clause.
- * 2. Rewrite when aggregate exists at least two distinct aggregates.
+ * The rule consists of two phases, as follows:
  *
- * The first child rule does the following things here:
- * 1. Guaranteed to compute filter clause locally.
+ * In the first phase, expand the data for distinct aggregates that have filter clauses:
+ * 1. Guarantee that each filter clause is evaluated locally, in the first aggregate.
  * 2. The attributes referenced by different distinct aggregate expressions are likely to overlap,
  *    and if no additional processing is performed, data loss will occur. To prevent this, we
  *    generate new attributes and replace the original ones.
- * 3. If we apply the first rule to distinct aggregate expressions which exists filter
- *    clause, the aggregate after expand may have at least two distinct aggregates, so we need to
- *    apply the second rule too.
+ * 3. After generating new attributes, the aggregate may have at least two distinct aggregates,
+ *    so we need the second phase too.
  *
- * The second child rule does the following things here:
+ * In the second phase, rewrite a query with two or more distinct groups:
  * 1. Expand the data. There are three aggregation groups in this query:
  *    i. the non-distinct group;
  *    ii. the distinct 'cat1 group;
@@ -207,30 +204,26 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     val distinctAggs = exprs.flatMap { _.collect {
       case ae: AggregateExpression if ae.isDistinct => ae
     }}
-    // This rule serves two purposes:
-    // One is to rewrite when there exists at least two distinct aggregates. We need at least
-    // two distinct aggregates for this rule because aggregation strategy can handle a single
-    // distinct group.
-    // Another is to expand distinct aggregates which exists filter clause so that we can
-    // evaluate the filter locally.
+    // We need at least two distinct aggregates, or a single distinct aggregate with a filter,
+    // because the aggregation strategy can handle a single distinct group without a filter.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
-    distinctAggs.size >= 1 || distinctAggs.exists(_.filter.isDefined)
+    distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) =>
-      val expandAggregate = extractFiltersInDistinctAggregate(a)
-      rewriteDistinctAggregate(expandAggregate)
+      val expandAggregate = extractFiltersInDistinctAggregates(a)
+      rewriteDistinctAggregates(expandAggregate)
   }
 
-  private def extractFiltersInDistinctAggregate(a: Aggregate): Aggregate = {
+  private def extractFiltersInDistinctAggregates(a: Aggregate): Aggregate = {
     val aggExpressions = collectAggregateExprs(a)
     val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct)
    if (distinctAggExpressions.exists(_.filter.isDefined)) {
-      // Setup expand for the 'regular' aggregate expressions. Because we will construct a new
-      // aggregate, the children of the distinct aggregates will be changed to the generate
-      // ones, so we need creates new references to avoid collisions between distinct and
-      // regular aggregate children.
+      // Constructs pairs between old and new expressions for regular aggregates. Because we
+      // will construct a new `Aggregate` and the children of the distinct aggregates will be
+      // changed to generated ones, we need to create new references to avoid collisions between
+      // distinct and regular aggregate children.
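+      // e.g. sum('value) FILTER (WHERE 'id > 200) keeps its filter after this step, but both
+      // 'value and 'id are read through the new references projected below.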
val regularAggExprs = regularAggExpressions.filter(_.children.exists(!_.foldable)) val regularFunChildren = regularAggExprs .flatMap(_.aggregateFunction.children.filter(!_.foldable)) @@ -238,7 +231,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) val regularAggChildAttrLookup = regularAggChildAttrMap.toMap - val regularAggMap = regularAggExprs.map { + val regularAggPairs = regularAggExprs.map { case ae @ AggregateExpression(af, _, _, filter, _) => val newChildren = af.children.map(c => regularAggChildAttrLookup.getOrElse(c, c)) val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] @@ -249,9 +242,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { (ae, aggExpr) } - // Setup expand for the 'distinct' aggregate expressions. + // Constructs pairs between old and new expressions for distinct aggregates, too. val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable)) - val (projections, expressionAttrs, aggExprPairs) = distinctAggExprs.map { + val (projections, expressionAttrs, distinctAggPairs) = distinctAggExprs.map { case ae @ AggregateExpression(af, _, _, filter, _) => // Why do we need to construct the `exprId` ? // First, In order to reduce costs, it is better to handle the filter clause locally. @@ -261,9 +254,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // same attributes. We need to construct the generated attributes so as the output not // lost. e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id > 1) will output // attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead of two 'a. - // Note: We just need to illusion the expression with filter clause. - // The illusionary mechanism may result in multiple distinct aggregations uses - // different column, so we still need to call `rewrite`. + // Note: The illusionary mechanism may result in at least two distinct groups, so we + // still need to call `rewrite`. val exprId = NamedExpression.newExprId.id val unfoldableChildren = af.children.filter(!_.foldable) val exprAttrs = unfoldableChildren.map { e => @@ -292,7 +284,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val groupByAttrs = groupByMap.map(_._2) // Construct the expand operator. val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs, a.child) - val rewriteAggExprLookup = (aggExprPairs ++ regularAggMap).toMap + val rewriteAggExprLookup = (distinctAggPairs ++ regularAggPairs).toMap val patchedAggExpressions = a.aggregateExpressions.map { e => e.transformDown { case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) @@ -305,7 +297,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } } - private def rewriteDistinctAggregate(a: Aggregate): Aggregate = { + private def rewriteDistinctAggregates(a: Aggregate): Aggregate = { val aggExpressions = collectAggregateExprs(a) // Extract distinct aggregate expressions. 
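
A minimal way to observe the two phases end to end (a sketch, not part of the patch; it assumes a SparkSession named `spark` and the `emp` table defined in group-by-filter.sql):

    val df = spark.sql(
      """SELECT dept_id,
        |       COUNT(DISTINCT emp_name) FILTER (WHERE id > 200) AS cnt,
        |       SUM(salary) AS total
        |FROM emp
        |GROUP BY dept_id""".stripMargin)
    // The optimized plan shows the projection injected by the first phase, with the
    // filter folded into an if-else column evaluated before aggregation.
    df.explain(true)
    df.show()
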
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 8cb939e010c68..8391b650e41f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} @@ -42,6 +42,16 @@ class RewriteDistinctAggregatesSuite extends PlanTest { case _ => fail(s"Plan is not rewritten:\n$rewrite") } + private def checkGenerate(generate: LogicalPlan): Unit = generate match { + case Aggregate(_, _, _: Expand) => + case _ => fail(s"Plan is not generated:\n$generate") + } + + private def checkGenerateAndRewrite(rewrite: LogicalPlan): Unit = rewrite match { + case Aggregate(_, _, Aggregate(_, _, Expand(_, _, _: Expand))) => + case _ => fail(s"Plan is not rewritten:\n$rewrite") + } + test("single distinct group") { val input = testRelation .groupBy('a)(countDistinct('e)) @@ -50,6 +60,13 @@ class RewriteDistinctAggregatesSuite extends PlanTest { comparePlans(input, rewrite) } + test("single distinct group with filter") { + val input = testRelation + .groupBy('a)(countDistinct(Some(EqualTo('d, Literal(""))), 'e)) + .analyze + checkGenerate(RewriteDistinctAggregates(input)) + } + test("single distinct group with partial aggregates") { val input = testRelation .groupBy('a, 'd)( @@ -67,6 +84,13 @@ class RewriteDistinctAggregatesSuite extends PlanTest { checkRewrite(RewriteDistinctAggregates(input)) } + test("multiple distinct groups with filter") { + val input = testRelation + .groupBy('a)(countDistinct(Some(EqualTo('d, Literal(""))), 'b, 'c), countDistinct('d)) + .analyze + checkGenerateAndRewrite(RewriteDistinctAggregates(input)) + } + test("multiple distinct groups with partial aggregates") { val input = testRelation .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) From a6498f9bd1b1e8e1744cc9859d07632ebc1ceb14 Mon Sep 17 00:00:00 2001 From: beliefer Date: Sun, 2 Feb 2020 16:24:21 +0800 Subject: [PATCH 04/21] Fix incorrect sql --- .../sql-tests/inputs/group-by-filter.sql | 2 +- .../sql-tests/results/group-by-filter.sql.out | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql index 14f0eb70657da..f065c7cf2979a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql @@ -74,7 +74,7 @@ select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), 
sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id; select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; -select dept_id, sum(distinct (id + dept_id))) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index 079b651445f62..00186b8173ff3 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -502,17 +502,16 @@ NULL 1 0 400.0 -- !query -select dept_id, sum(distinct (id + dept_id))) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id -- !query schema -struct<> +struct 200)):bigint,count(DISTINCT hiredate):bigint,sum(salary):double> -- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input ')' expecting {, '(', ',', '.', '[', 'ADD', 'AFTER', 'ALL', 'ALTER', 'ANALYZE', 'AND', 'ANTI', 'ANY', 'ARCHIVE', 'ARRAY', 'AS', 'ASC', 'AT', 'AUTHORIZATION', 'BETWEEN', 'BOTH', 'BUCKET', 'BUCKETS', 'BY', 'CACHE', 'CASCADE', 'CASE', 'CAST', 'CHANGE', 'CHECK', 'CLEAR', 'CLUSTER', 'CLUSTERED', 'CODEGEN', 'COLLATE', 'COLLECTION', 'COLUMN', 'COLUMNS', 'COMMENT', 'COMMIT', 'COMPACT', 'COMPACTIONS', 'COMPUTE', 'CONCATENATE', 'CONSTRAINT', 'COST', 'CREATE', 'CROSS', 'CUBE', 'CURRENT', 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'CURRENT_USER', 'DATA', 'DATABASE', DATABASES, 'DAY', 'DBPROPERTIES', 'DEFINED', 'DELETE', 'DELIMITED', 'DESC', 'DESCRIBE', 'DFS', 'DIRECTORIES', 'DIRECTORY', 'DISTINCT', 'DISTRIBUTE', 'DROP', 'ELSE', 'END', 'ESCAPE', 'ESCAPED', 'EXCEPT', 'EXCHANGE', 'EXISTS', 'EXPLAIN', 'EXPORT', 'EXTENDED', 'EXTERNAL', 'EXTRACT', 'FALSE', 'FETCH', 'FIELDS', 'FILTER', 'FILEFORMAT', 'FIRST', 'FIRST_VALUE', 'FOLLOWING', 'FOR', 'FOREIGN', 'FORMAT', 'FORMATTED', 'FROM', 'FULL', 'FUNCTION', 'FUNCTIONS', 'GLOBAL', 'GRANT', 'GROUP', 'GROUPING', 'HAVING', 'HOUR', 'IF', 'IGNORE', 'IMPORT', 'IN', 'INDEX', 'INDEXES', 'INNER', 'INPATH', 'INPUTFORMAT', 'INSERT', 'INTERSECT', 'INTERVAL', 'INTO', 'IS', 'ITEMS', 'JOIN', 'KEYS', 'LAST', 'LAST_VALUE', 'LATERAL', 'LAZY', 'LEADING', 'LEFT', 'LIKE', 'LIMIT', 'LINES', 'LIST', 'LOAD', 'LOCAL', 'LOCATION', 'LOCK', 'LOCKS', 'LOGICAL', 'MACRO', 'MAP', 'MATCHED', 'MERGE', 'MINUTE', 'MONTH', 'MSCK', 'NAMESPACE', 'NAMESPACES', 'NATURAL', 'NO', NOT, 'NULL', 'NULLS', 'OF', 'ON', 'ONLY', 'OPTION', 'OPTIONS', 'OR', 'ORDER', 'OUT', 'OUTER', 'OUTPUTFORMAT', 'OVER', 
'OVERLAPS', 'OVERLAY', 'OVERWRITE', 'PARTITION', 'PARTITIONED', 'PARTITIONS', 'PERCENT', 'PIVOT', 'PLACING', 'POSITION', 'PRECEDING', 'PRIMARY', 'PRINCIPALS', 'PROPERTIES', 'PURGE', 'QUERY', 'RANGE', 'RECORDREADER', 'RECORDWRITER', 'RECOVER', 'REDUCE', 'REFERENCES', 'REFRESH', 'RENAME', 'REPAIR', 'REPLACE', 'RESET', 'RESPECT', 'RESTRICT', 'REVOKE', 'RIGHT', RLIKE, 'ROLE', 'ROLES', 'ROLLBACK', 'ROLLUP', 'ROW', 'ROWS', 'SCHEMA', 'SECOND', 'SELECT', 'SEMI', 'SEPARATED', 'SERDE', 'SERDEPROPERTIES', 'SESSION_USER', 'SET', 'MINUS', 'SETS', 'SHOW', 'SKEWED', 'SOME', 'SORT', 'SORTED', 'START', 'STATISTICS', 'STORED', 'STRATIFY', 'STRUCT', 'SUBSTR', 'SUBSTRING', 'TABLE', 'TABLES', 'TABLESAMPLE', 'TBLPROPERTIES', TEMPORARY, 'TERMINATED', 'THEN', 'TO', 'TOUCH', 'TRAILING', 'TRANSACTION', 'TRANSACTIONS', 'TRANSFORM', 'TRIM', 'TRUE', 'TRUNCATE', 'TYPE', 'UNARCHIVE', 'UNBOUNDED', 'UNCACHE', 'UNION', 'UNIQUE', 'UNKNOWN', 'UNLOCK', 'UNSET', 'UPDATE', 'USE', 'USER', 'USING', 'VALUES', 'VIEW', 'WHEN', 'WHERE', 'WINDOW', 'WITH', 'YEAR', EQ, '<=>', '<>', '!=', '<', LTE, '>', GTE, '+', '-', '*', '/', '%', 'DIV', '&', '|', '||', '^', IDENTIFIER, BACKQUOTED_IDENTIFIER}(line 1, pos 44) - -== SQL == -select dept_id, sum(distinct (id + dept_id))) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id ---------------------------------------------^^^ +10 NULL 2 400.0 +100 1500 2 800.0 +20 320 1 300.0 +30 430 1 400.0 +70 870 1 150.0 +NULL NULL 1 400.0 -- !query From cd00f915cae0ba453903d6f7329ff11289801954 Mon Sep 17 00:00:00 2001 From: beliefer Date: Fri, 7 Feb 2020 20:06:10 +0800 Subject: [PATCH 05/21] Fix conflict --- .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5cc0453135c07..551dca7b3669a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -206,11 +206,6 @@ class AnalysisErrorSuite extends AnalysisTest { "FILTER (WHERE c > 1)"), "FILTER predicate specified, but aggregate is not an aggregate function" :: Nil) - errorTest( - "DISTINCT aggregate function with filter predicate", - CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"), - "DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil) - errorTest( "non-deterministic filter predicate in aggregate functions", CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"), From 4a6f903897d28a3038918997e692410259a90ae3 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 19 Jun 2020 10:36:52 +0800 Subject: [PATCH 06/21] Reuse completeNextStageWithFetchFailure --- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9d412f2dba3ce..762b14e170fcc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1796,9 +1796,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // lets say 
there is a fetch failure in this task set, which makes us go back and // run stage 0, attempt 1 - complete(taskSets(1), Seq( - (FetchFailed(makeBlockManagerId("hostA"), - shuffleDep1.shuffleId, 0L, 0, 0, "ignored"), null))) + completeNextStageWithFetchFailure(1, 0, shuffleDep1) scheduler.resubmitFailedStages() // stage 0, attempt 1 should have the properties of job2 @@ -1872,9 +1870,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have the second stage complete normally completeShuffleMapStageSuccessfully(1, 0, 1, Seq("hostA", "hostC")) // fail the third stage because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), - shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null))) + completeNextStageWithFetchFailure(2, 0, shuffleDepTwo) // TODO assert this: // blockManagerMaster.removeExecutor("hostA-exec") // have DAGScheduler try again @@ -1900,9 +1896,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // complete stage 1 completeShuffleMapStageSuccessfully(1, 0, 1) // pretend stage 2 failed because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), - shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null))) + completeNextStageWithFetchFailure(2, 0, shuffleDepTwo) // TODO assert this: // blockManagerMaster.removeExecutor("hostA-exec") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. From a56f2b0fe7fa25b919fc88568760e34779cc0d87 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 3 Jul 2020 11:26:55 +0800 Subject: [PATCH 07/21] Optimize comments --- .../sql/catalyst/expressions/aggregate/interfaces.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 698c53c697706..0e46dcc0ee3d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -217,10 +217,10 @@ abstract class AggregateFunction extends Expression { /** * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] with `isDistinct` - * flag and `filter` option of the [[AggregateExpression]] to the given value because + * flag and an optional `filter` of the [[AggregateExpression]] to the given value because * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, - * the flag indicating if this aggregation is distinct aggregation or not and filter option. - * An [[AggregateFunction]] should not be used without being wrapped in + * the flag indicating if this aggregation is distinct aggregation or not and the optional + * `filter`. An [[AggregateFunction]] should not be used without being wrapped in * an [[AggregateExpression]]. 
*/ def toAggregateExpression( From 7d6ada47062ea7d4afbfac10b12e578fb627ec2f Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 3 Jul 2020 16:45:56 +0800 Subject: [PATCH 08/21] Expand to Project --- .../optimizer/RewriteDistinctAggregates.scala | 68 ++++++++++++------- .../RewriteDistinctAggregatesSuite.scala | 6 +- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 62f74459937a5..1911527c808c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.IntegerType @@ -229,7 +229,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { .flatMap(_.aggregateFunction.children.filter(!_.foldable)) val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct - val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) + val regularAggChildrenMap = regularAggChildren.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => other -> Alias(other, other.toString)() + } + val regularAggChildAttrMap = regularAggChildrenMap.map { kv => + (kv._1, kv._2.toAttribute) + } val regularAggChildAttrLookup = regularAggChildAttrMap.toMap val regularAggPairs = regularAggExprs.map { case ae @ AggregateExpression(af, _, _, filter, _) => @@ -244,7 +253,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Constructs pairs between old and new expressions for distinct aggregates, too. val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable)) - val (projections, expressionAttrs, distinctAggPairs) = distinctAggExprs.map { + val (projections, distinctAggPairs) = distinctAggExprs.map { case ae @ AggregateExpression(af, _, _, filter, _) => // Why do we need to construct the `exprId` ? // First, In order to reduce costs, it is better to handle the filter clause locally. @@ -256,41 +265,54 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead of two 'a. // Note: The illusionary mechanism may result in at least two distinct groups, so we // still need to call `rewrite`. 
- val exprId = NamedExpression.newExprId.id val unfoldableChildren = af.children.filter(!_.foldable) - val exprAttrs = unfoldableChildren.map { e => - (e, AttributeReference(s"_gen_distinct_$exprId", e.dataType, nullable = true)()) + // Expand projection + val projectionMap = unfoldableChildren.map { + case e if filter.isDefined => + val ife = If(filter.get, e, nullify(e)) + e -> Alias(ife, ife.toString)() + case ne: NamedExpression => ne -> ne + case e => e -> Alias(e, e.toString)() + } + val projection = projectionMap.map(_._2) + val exprAttrs = projectionMap.map { kv => + (kv._1, kv._2.toAttribute) } val exprAttrLookup = exprAttrs.toMap val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c)) val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] val aggExpr = ae.copy(aggregateFunction = raf, filter = None) - // Expand projection - val projection = unfoldableChildren.map { - case e if filter.isDefined => If(filter.get, e, nullify(e)) - case e => e - } - (projection, exprAttrs, (ae, aggExpr)) - }.unzip3 - val distinctAggChildAttrs = expressionAttrs.flatten.map(_._2) - val allAggAttrs = regularAggChildAttrMap.map(_._2) ++ distinctAggChildAttrs + (projection, (ae, aggExpr)) + }.unzip +// val distinctAggChildAttrs = expressionAttrs.flatten.map(_._2) +// val allAggAttrs = regularAggChildAttrMap.map(_._2) ++ distinctAggChildAttrs // Construct the aggregate input projection. - val rewriteAggProjections = - Seq(a.groupingExpressions ++ regularAggChildren ++ projections.flatten) - val groupByMap = a.groupingExpressions.collect { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() + val namedGroupingExpressions = a.groupingExpressions.map { + case ne: NamedExpression => ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => Alias(other, other.toString)() } - val groupByAttrs = groupByMap.map(_._2) + val rewriteAggProjection = namedGroupingExpressions ++ projections.flatten + val project = Project(rewriteAggProjection, a.child) +// val rewriteAggProjections = +// Seq(a.groupingExpressions ++ regularAggChildren ++ projections.flatten) +// val groupByMap = a.groupingExpressions.collect { +// case ne: NamedExpression => ne -> ne.toAttribute +// case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() +// } +// val groupByAttrs = groupByMap.map(_._2) + val groupByAttrs = namedGroupingExpressions.map(_.toAttribute) // Construct the expand operator. 
- val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs, a.child) +// val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs, a.child) val rewriteAggExprLookup = (distinctAggPairs ++ regularAggPairs).toMap val patchedAggExpressions = a.aggregateExpressions.map { e => e.transformDown { case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) }.asInstanceOf[NamedExpression] } - val expandAggregate = Aggregate(groupByAttrs, patchedAggExpressions, expand) + val expandAggregate = Aggregate(groupByAttrs, patchedAggExpressions, project) expandAggregate } else { a diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 8391b650e41f7..a7d1288d10a8a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} import org.apache.spark.sql.types.{IntegerType, StringType} @@ -43,12 +43,12 @@ class RewriteDistinctAggregatesSuite extends PlanTest { } private def checkGenerate(generate: LogicalPlan): Unit = generate match { - case Aggregate(_, _, _: Expand) => + case Aggregate(_, _, _: Project) => case _ => fail(s"Plan is not generated:\n$generate") } private def checkGenerateAndRewrite(rewrite: LogicalPlan): Unit = rewrite match { - case Aggregate(_, _, Aggregate(_, _, Expand(_, _, _: Expand))) => + case Aggregate(_, _, Aggregate(_, _, Expand(_, _, _: Project))) => case _ => fail(s"Plan is not rewritten:\n$rewrite") } From 529b69ebb5e12c50b81231f81b1dc728e416f7ca Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 3 Jul 2020 19:16:58 +0800 Subject: [PATCH 09/21] Expand to Project --- .../optimizer/RewriteDistinctAggregates.scala | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 1911527c808c6..19e7e35e81542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -231,9 +231,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct val regularAggChildrenMap = regularAggChildren.map { case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. 
case other => other -> Alias(other, other.toString)() } val regularAggChildAttrMap = regularAggChildrenMap.map { kv => @@ -289,9 +286,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Construct the aggregate input projection. val namedGroupingExpressions = a.groupingExpressions.map { case ne: NamedExpression => ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. case other => Alias(other, other.toString)() } val rewriteAggProjection = namedGroupingExpressions ++ projections.flatten @@ -502,4 +496,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // children, in this case attribute reuse causes the input of the regular aggregate to bound to // the (nulled out) input of the distinct aggregate. e -> AttributeReference(e.sql, e.dataType, nullable = true)() + + private def addAlias(expressions: Seq[Expression]) = expressions.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } } From 54f6d84d7a455dfc471daea3f8ecb9d348920916 Mon Sep 17 00:00:00 2001 From: beliefer Date: Sat, 4 Jul 2020 14:42:50 +0800 Subject: [PATCH 10/21] change Expand to Project --- .../optimizer/RewriteDistinctAggregates.scala | 29 +++++++------------ .../sql-tests/results/group-by-filter.sql.out | 2 +- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 19e7e35e81542..f7c7525c9872e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -158,8 +158,8 @@ import org.apache.spark.sql.types.IntegerType * ('key, '_gen_distinct_1, null, 1, null), * ('key, null, '_gen_distinct_2, 2, null)] * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value]) - * Expand( - * projections = [('key, if ('id > 1) 'cat1 else null, 'cat2, cast('value as bigint))] + * Project( + * projectList = ['key, if ('id > 1) 'cat1 else null, 'cat2, cast('value as bigint)] * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'value]) * LocalTableScan [...] * }}} @@ -233,6 +233,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { case ne: NamedExpression => ne -> ne case other => other -> Alias(other, other.toString)() } + val namedRegularAggChildren = regularAggChildrenMap.map(_._2) val regularAggChildAttrMap = regularAggChildrenMap.map { kv => (kv._1, kv._2.toAttribute) } @@ -262,14 +263,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead of two 'a. // Note: The illusionary mechanism may result in at least two distinct groups, so we // still need to call `rewrite`. 
+ val exprId = NamedExpression.newExprId.id val unfoldableChildren = af.children.filter(!_.foldable) // Expand projection val projectionMap = unfoldableChildren.map { case e if filter.isDefined => val ife = If(filter.get, e, nullify(e)) - e -> Alias(ife, ife.toString)() - case ne: NamedExpression => ne -> ne - case e => e -> Alias(e, e.toString)() + e -> Alias(ife, s"_gen_distinct_$exprId")() + case e => e -> Alias(e, s"_gen_distinct_$exprId")() } val projection = projectionMap.map(_._2) val exprAttrs = projectionMap.map { kv => @@ -281,33 +282,23 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val aggExpr = ae.copy(aggregateFunction = raf, filter = None) (projection, (ae, aggExpr)) }.unzip -// val distinctAggChildAttrs = expressionAttrs.flatten.map(_._2) -// val allAggAttrs = regularAggChildAttrMap.map(_._2) ++ distinctAggChildAttrs // Construct the aggregate input projection. val namedGroupingExpressions = a.groupingExpressions.map { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } - val rewriteAggProjection = namedGroupingExpressions ++ projections.flatten + val rewriteAggProjection = + namedGroupingExpressions ++ namedRegularAggChildren ++ projections.flatten + // Construct the project operator. val project = Project(rewriteAggProjection, a.child) -// val rewriteAggProjections = -// Seq(a.groupingExpressions ++ regularAggChildren ++ projections.flatten) -// val groupByMap = a.groupingExpressions.collect { -// case ne: NamedExpression => ne -> ne.toAttribute -// case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() -// } -// val groupByAttrs = groupByMap.map(_._2) val groupByAttrs = namedGroupingExpressions.map(_.toAttribute) - // Construct the expand operator. -// val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs, a.child) val rewriteAggExprLookup = (distinctAggPairs ++ regularAggPairs).toMap val patchedAggExpressions = a.aggregateExpressions.map { e => e.transformDown { case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) }.asInstanceOf[NamedExpression] } - val expandAggregate = Aggregate(groupByAttrs, patchedAggExpressions, project) - expandAggregate + Aggregate(groupByAttrs, patchedAggExpressions, project) } else { a } diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index 43b99772f3c2f..669f39de00568 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -113,7 +113,7 @@ struct +struct -- !query output 2 2 From a7bcbc9baa632d31d8f6b483d3ba98db4826a85e Mon Sep 17 00:00:00 2001 From: beliefer Date: Sat, 4 Jul 2020 14:49:25 +0800 Subject: [PATCH 11/21] Optimize code --- .../sql/catalyst/optimizer/RewriteDistinctAggregates.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index f7c7525c9872e..72cf29a1ee2e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -487,9 +487,4 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // children, in this case attribute reuse causes the input of the regular 
aggregate to bound to // the (nulled out) input of the distinct aggregate. e -> AttributeReference(e.sql, e.dataType, nullable = true)() - - private def addAlias(expressions: Seq[Expression]) = expressions.map { - case ne: NamedExpression => ne - case other => Alias(other, other.toString)() - } } From 73dc600015d98a2a64b76293fa17bed89254dd2c Mon Sep 17 00:00:00 2001 From: beliefer Date: Sat, 4 Jul 2020 23:56:24 +0800 Subject: [PATCH 12/21] Optimize code --- .../optimizer/RewriteDistinctAggregates.scala | 11 +++----- .../sql-tests/inputs/group-by-filter.sql | 2 ++ .../sql-tests/results/group-by-filter.sql.out | 28 ++++++++++++++++++- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 72cf29a1ee2e9..68c0c2b5714c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -234,10 +234,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { case other => other -> Alias(other, other.toString)() } val namedRegularAggChildren = regularAggChildrenMap.map(_._2) - val regularAggChildAttrMap = regularAggChildrenMap.map { kv => + val regularAggChildAttrLookup = regularAggChildrenMap.map { kv => (kv._1, kv._2.toAttribute) - } - val regularAggChildAttrLookup = regularAggChildAttrMap.toMap + }.toMap val regularAggPairs = regularAggExprs.map { case ae @ AggregateExpression(af, _, _, filter, _) => val newChildren = af.children.map(c => regularAggChildAttrLookup.getOrElse(c, c)) @@ -253,7 +252,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable)) val (projections, distinctAggPairs) = distinctAggExprs.map { case ae @ AggregateExpression(af, _, _, filter, _) => - // Why do we need to construct the `exprId` ? // First, In order to reduce costs, it is better to handle the filter clause locally. // e.g. COUNT (DISTINCT a) FILTER (WHERE id > 1), evaluate expression // If(id > 1) 'a else null first, and use the result as output. @@ -263,14 +261,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead of two 'a. // Note: The illusionary mechanism may result in at least two distinct groups, so we // still need to call `rewrite`. 
- val exprId = NamedExpression.newExprId.id val unfoldableChildren = af.children.filter(!_.foldable) // Expand projection val projectionMap = unfoldableChildren.map { case e if filter.isDefined => val ife = If(filter.get, e, nullify(e)) - e -> Alias(ife, s"_gen_distinct_$exprId")() - case e => e -> Alias(e, s"_gen_distinct_$exprId")() + e -> Alias(ife, s"_gen_distinct_${NamedExpression.newExprId.id}")() + case e => e -> Alias(e, s"_gen_distinct_${NamedExpression.newExprId.id}")() } val projection = projectionMap.map(_._2) val exprAttrs = projectionMap.map { kv => diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql index f065c7cf2979a..c36fd3bee7ded 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql @@ -77,6 +77,8 @@ select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; +select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id; -- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE. SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index 669f39de00568..4dac09fcefe2e 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 61 +-- Number of queries: 63 -- !query @@ -540,6 +540,32 @@ struct 200)):double NULL NULL 0 NULL +-- !query +select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id +-- !query schema +struct 0)):bigint,sum(salary):double> +-- !query output +10 2 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + -- !query SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema From 5a4ca029d8d04c30e1d6d76932497b85b4624383 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Tue, 7 Jul 2020 10:44:38 +0800 Subject: [PATCH 13/21] Supplement docs. 
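
A quick way to see the documented two-phase rewrite from the SQL surface is a
spark-shell session like the sketch below. The session, view and column names
are illustrative assumptions, not part of this patch; it only requires the
series applied up to this point.

    // Sketch: inspect the plan chosen for a distinct aggregate with FILTER.
    import spark.implicits._

    Seq((1, "x", "a1", "b1", 3), (2, "x", "a2", "b1", 4), (3, "y", "a1", "b2", 5))
      .toDF("id", "key", "cat1", "cat2", "value")
      .createOrReplaceTempView("data")

    spark.sql(
      """SELECT COUNT(DISTINCT cat1) FILTER (WHERE id > 1) AS cat1_cnt,
        |       SUM(value) AS total
        |FROM data
        |GROUP BY key""".stripMargin)
      .explain(true)  // the optimized plan shows the node inserted by the first phase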
--- .../optimizer/RewriteDistinctAggregates.scala | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 68c0c2b5714c4..d4526726c097e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -166,13 +166,28 @@ import org.apache.spark.sql.types.IntegerType * * The rule consists of the two phases as follows: * - * In the first phase, expands data for the distinct aggregates where filter clauses exist: - * 1. Guaranteed to compute filter clauses in the first aggregate locally. - * 2. The attributes referenced by different distinct aggregate expressions are likely to overlap, - * and if no additional processing is performed, data loss will occur. To prevent this, we - * generate new attributes and replace the original ones. - * 3. After generate new attributes, the aggregate may have at least two distinct aggregates, - * so we need the second phase too. + * In the first phase, if the aggregate query with distinct aggregations and + * filter clauses, project the output of the child of the aggregate query: + * 1. Project the data. There are three aggregation groups in this query: + * i. the non-distinct group; + * ii. the distinct 'cat1 group; + * iii. the distinct 'cat2 group with filter clause. + * Because there is at least one distinct group with filter clause (e.g. the distinct 'cat2 + * group with filter clause), then will project the data. + * 2. Avoid projections that may output the same attributes. There are three aggregation groups + * in this query: + * i. the non-distinct group; + * ii. the distinct 'cat1 group; + * iii. the distinct 'cat1 group with filter clause. + * The attributes referenced by different distinct aggregate expressions are likely to overlap, + * and if no additional processing is performed, data loss will occur. If we directly output + * the attributes of the aggregate expression, we may get two attributes 'cat1. To prevent + * this, we generate new attributes (e.g. '_gen_distinct_1) and replace the original ones. + * + * Why we need the first phase? guaranteed to compute filter clauses in the first aggregate + * locally. + * Note: after generate new attributes, the aggregate may have at least two distinct aggregates, + * so we need the second phase too. * * In the second phase, rewrite a query with two or more distinct groups: * 1. Expand the data. 
There are three aggregation groups in this query: @@ -212,11 +227,11 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => - val expandAggregate = extractFiltersInDistinctAggregates(a) + val expandAggregate = projectFiltersInDistinctAggregates(a) rewriteDistinctAggregates(expandAggregate) } - private def extractFiltersInDistinctAggregates(a: Aggregate): Aggregate = { + private def projectFiltersInDistinctAggregates(a: Aggregate): Aggregate = { val aggExpressions = collectAggregateExprs(a) val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct) if (distinctAggExpressions.exists(_.filter.isDefined)) { @@ -267,6 +282,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { case e if filter.isDefined => val ife = If(filter.get, e, nullify(e)) e -> Alias(ife, s"_gen_distinct_${NamedExpression.newExprId.id}")() + // For convenience and unification, we always alias the distinct column, even if + // there is no filter. case e => e -> Alias(e, s"_gen_distinct_${NamedExpression.newExprId.id}")() } val projection = projectionMap.map(_._2) From 70ff08e2e0863bd3a25ca371331bac281ca44190 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Wed, 8 Jul 2020 14:48:38 +0800 Subject: [PATCH 14/21] Merge project with expand --- .../optimizer/RewriteDistinctAggregates.scala | 58 +++++++++++++------ 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index d4526726c097e..ba86840d93f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -154,14 +154,11 @@ import org.apache.spark.sql.types.IntegerType * functions = [sum('value)] * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'total]) * Expand( - * projections = [('key, null, null, 0, 'value), - * ('key, '_gen_distinct_1, null, 1, null), - * ('key, null, '_gen_distinct_2, 2, null)] - * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value]) - * Project( - * projectList = ['key, if ('id > 1) 'cat1 else null, 'cat2, cast('value as bigint)] - * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'value]) - * LocalTableScan [...] + * projections = [('key, null, null, 0, cast('value as bigint)), + * ('key, if ('id > 1) 'cat1 else null, null, 1, null), + * ('key, null, 'cat2, 2, null)] + * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value]) + * LocalTableScan [...] * }}} * * The rule consists of the two phases as follows: @@ -206,6 +203,9 @@ import org.apache.spark.sql.types.IntegerType * aggregation. In this step we use the group id to filter the inputs for the aggregate * functions. The result of the non-distinct group are 'aggregated' by using the first operator, * it might be more elegant to use the native UDAF merge mechanism for this in the future. + * 4. If the first phase inserted a project operator as the child of aggregate and the second phase + * already decided to insert an expand operator as the child of aggregate, the second phase will + * merge the project operator with expand operator. 
* * This rule duplicates the input data by two or more times (# distinct groups + an optional * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and @@ -227,11 +227,11 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => - val expandAggregate = projectFiltersInDistinctAggregates(a) - rewriteDistinctAggregates(expandAggregate) + val (aggregate, projected) = projectFiltersInDistinctAggregates(a) + rewriteDistinctAggregates(aggregate, projected) } - private def projectFiltersInDistinctAggregates(a: Aggregate): Aggregate = { + private def projectFiltersInDistinctAggregates(a: Aggregate): (Aggregate, Boolean) = { val aggExpressions = collectAggregateExprs(a) val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct) if (distinctAggExpressions.exists(_.filter.isDefined)) { @@ -312,13 +312,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) }.asInstanceOf[NamedExpression] } - Aggregate(groupByAttrs, patchedAggExpressions, project) + (Aggregate(groupByAttrs, patchedAggExpressions, project), true) } else { - a + (a, false) } } - private def rewriteDistinctAggregates(a: Aggregate): Aggregate = { + private def rewriteDistinctAggregates(a: Aggregate, projected: Boolean): Aggregate = { val aggExpressions = collectAggregateExprs(a) // Extract distinct aggregate expressions. @@ -359,7 +359,15 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Setup unique distinct aggregate children. val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct - val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) + val distinctAggChildAttrMap = if (projected) { + // To facilitate merging Project with Expand, not need creating a new reference here. + distinctAggChildren.map { + case ar: AttributeReference => ar -> ar + case other => expressionAttributePair(other) + } + } else { + distinctAggChildren.map(expressionAttributePair) + } val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) // Setup expand & aggregate operators for distinct aggregate expressions. @@ -448,11 +456,27 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { regularAggNulls } + val (projections, expandChild) = if (projected) { + // If `projectFiltersInDistinctAggregates` inserts Project as child of Aggregate and + // `rewriteDistinctAggregates` will insert Expand here, merge Project with the Expand. + val projectAttributeExpressionMap = a.child.asInstanceOf[Project].projectList.map { + case ne: NamedExpression => ne.name -> ne + }.toMap + val projections = (regularAggProjection ++ distinctAggProjections).map { + case projection: Seq[Expression] => projection.map { + case ne: NamedExpression => projectAttributeExpressionMap.getOrElse(ne.name, ne) + case other => other + } + } + (projections, a.child.asInstanceOf[Project].child) + } else { + (regularAggProjection ++ distinctAggProjections, a.child) + } // Construct the expand operator. val expand = Expand( - regularAggProjection ++ distinctAggProjections, + projections, groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), - a.child) + expandChild) // Construct the first aggregate operator. This de-duplicates all the children of // distinct operators, and applies the regular aggregate operators. 
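
For intuition, the Project-into-Expand merge above boils down to substituting
every attribute the lower Project defines with its defining expression inside
each Expand projection. Below is a self-contained sketch of that substitution,
using stand-in types rather than Catalyst's real expression classes:

    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class Named(name: String, child: Expr) extends Expr

    // Replace references to the Project's outputs by their defining
    // expressions, so the Project beneath the Expand can be dropped.
    def mergeProjectIntoExpand(
        expandProjections: Seq[Seq[Expr]],
        projectList: Seq[Named]): Seq[Seq[Expr]] = {
      val defining: Map[String, Expr] = projectList.map(n => n.name -> (n: Expr)).toMap
      expandProjections.map(_.map {
        case Attr(name) => defining.getOrElse(name, Attr(name))
        case other => other
      })
    }

The real rule performs the same name-keyed lookup via
`projectAttributeExpressionMap` in `rewriteDistinctAggregates`.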
From 16d8c1d26681c875d61e50305de43f3ca5e75154 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Wed, 8 Jul 2020 14:52:10 +0800
Subject: [PATCH 15/21] Merge project with expand

---
 .../sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index a7d1288d10a8a..393d28fc3244f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -48,7 +48,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
   }
 
   private def checkGenerateAndRewrite(rewrite: LogicalPlan): Unit = rewrite match {
-    case Aggregate(_, _, Aggregate(_, _, Expand(_, _, _: Project))) =>
+    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
     case _ => fail(s"Plan is not rewritten:\n$rewrite")
   }
 
From 3c491564c8db4935699a25eb8f9fc74367f6ee95 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Thu, 9 Jul 2020 15:27:47 +0800
Subject: [PATCH 16/21] Supplement comments.

---
 .../optimizer/RewriteDistinctAggregates.scala | 37 ++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index ba86840d93f2e..8d23a223a0ff5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -118,7 +118,42 @@ import org.apache.spark.sql.types.IntegerType
 *       LocalTableScan [...]
 * }}}
 *
- * Third example: single distinct aggregate function with filter clauses (in sql):
+ * Third example: single distinct aggregate function with a filter clause and
+ * no other distinct aggregate function (in sql):
+ * {{{
+ *   SELECT
+ *     COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
+ *     SUM(value) AS total
+ *   FROM
+ *     data
+ *   GROUP BY
+ *     key
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [COUNT(DISTINCT 'cat1) with FILTER('id > 1),
+ *                 sum('value)]
+ *    output = ['key, 'cat1_cnt, 'total])
+ *   LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [count('_gen_distinct_1),
+ *                 sum('value)]
+ *    output = ['key, 'cat1_cnt, 'total])
+ *   Project(
+ *      projectionList = ['key, if ('id > 1) 'cat1 else null, cast('value as bigint)]
+ *      output = ['key, '_gen_distinct_1, 'value])
+ *     LocalTableScan [...]
+ * }}}
+ *
+ * Fourth example: single distinct aggregate function with filter clauses (in sql):
 * {{{
 *   SELECT
 *     COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
From 762e839f8968103ef3c204f83ece187c83086933 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Thu, 9 Jul 2020 18:45:50 +0800
Subject: [PATCH 17/21] Optimize code.
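
This drops the `projected` special case when building `distinctAggChildAttrMap`.
One way to sanity-check the behavior it preserves: a single distinct aggregate
with a filter should still plan without an Expand. A sketch, reusing the
hypothetical `data` view from the note on PATCH 13:

    import org.apache.spark.sql.catalyst.plans.logical.Expand

    val optimized = spark.sql(
      """SELECT COUNT(DISTINCT cat1) FILTER (WHERE id > 1) AS cat1_cnt
        |FROM data
        |GROUP BY key""".stripMargin)
      .queryExecution.optimizedPlan

    // One distinct group needs no Expand; the filter is handled by the
    // Project inserted in the first phase.
    assert(optimized.collect { case e: Expand => e }.isEmpty)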
--- .../catalyst/optimizer/RewriteDistinctAggregates.scala | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 8d23a223a0ff5..9771706d51cbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -394,15 +394,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Setup unique distinct aggregate children. val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct - val distinctAggChildAttrMap = if (projected) { - // To facilitate merging Project with Expand, not need creating a new reference here. - distinctAggChildren.map { - case ar: AttributeReference => ar -> ar - case other => expressionAttributePair(other) - } - } else { - distinctAggChildren.map(expressionAttributePair) - } + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) // Setup expand & aggregate operators for distinct aggregate expressions. From 12e6fbccba0773d955c9e5fed62da0a7e4f0c1e4 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Mon, 13 Jul 2020 15:32:54 +0800 Subject: [PATCH 18/21] Unified implementation of filter in regular aggregates and distinct aggregates. --- .../optimizer/RewriteDistinctAggregates.scala | 116 ++++++------------ .../aggregate/AggregationIterator.scala | 39 ++---- 2 files changed, 47 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 9771706d51cbb..8533ca47ebec6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -102,19 +102,19 @@ import org.apache.spark.sql.types.IntegerType * {{{ * Aggregate( * key = ['key] - * functions = [count(if (('gid = 1)) 'cat1 else null), - * count(if (('gid = 2)) 'cat2 else null), + * functions = [count(if (('gid = 1)) '_gen_attr_1 else null), + * count(if (('gid = 2)) '_gen_attr_2 else null), * first(if (('gid = 0)) 'total else null) ignore nulls] * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) * Aggregate( - * key = ['key, 'cat1, 'cat2, 'gid] - * functions = [sum('value) with FILTER('id > 1)] - * output = ['key, 'cat1, 'cat2, 'gid, 'total]) + * key = ['key, '_gen_attr_1, '_gen_attr_2, 'gid] + * functions = [sum('_gen_attr_3)] + * output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, 'total]) * Expand( - * projections = [('key, null, null, 0, cast('value as bigint), 'id), + * projections = [('key, null, null, 0, if ('id > 1) cast('value as bigint) else null, 'id), * ('key, 'cat1, null, 1, null, null), * ('key, null, 'cat2, 2, null, null)] - * output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id]) + * output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, '_gen_attr_3, 'id]) * LocalTableScan [...] 
* }}} * @@ -144,12 +144,12 @@ import org.apache.spark.sql.types.IntegerType * {{{ * Aggregate( * key = ['key] - * functions = [count('_gen_distinct_1), - * sum('value)] + * functions = [count('_gen_attr_1), + * sum('_gen_attr_2)] * output = ['key, 'cat1_cnt, 'total]) * Project( * projectionList = ['key, if ('id > 1) 'cat1 else null, cast('value as bigint)] - * output = ['key, '_gen_distinct_1, 'value]) + * output = ['key, '_gen_attr_1, '_gen_attr_2]) * LocalTableScan [...] * }}} * @@ -180,45 +180,45 @@ import org.apache.spark.sql.types.IntegerType * {{{ * Aggregate( * key = ['key] - * functions = [count(if (('gid = 1)) '_gen_distinct_1 else null), - * count(if (('gid = 2)) '_gen_distinct_2 else null), + * functions = [count(if (('gid = 1)) '_gen_attr_1 else null), + * count(if (('gid = 2)) '_gen_attr_2 else null), * first(if (('gid = 0)) 'total else null) ignore nulls] * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) * Aggregate( - * key = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid] - * functions = [sum('value)] - * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'total]) + * key = ['key, '_gen_attr_1, '_gen_attr_2, 'gid] + * functions = [sum('_gen_attr_3)] + * output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, 'total]) * Expand( * projections = [('key, null, null, 0, cast('value as bigint)), * ('key, if ('id > 1) 'cat1 else null, null, 1, null), * ('key, null, 'cat2, 2, null)] - * output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value]) + * output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, '_gen_attr_3]) * LocalTableScan [...] * }}} * * The rule consists of the two phases as follows: * - * In the first phase, if the aggregate query with distinct aggregations and - * filter clauses, project the output of the child of the aggregate query: + * In the first phase, if the aggregate query exists filter clauses, project the output of + * the child of the aggregate query: * 1. Project the data. There are three aggregation groups in this query: * i. the non-distinct group; * ii. the distinct 'cat1 group; * iii. the distinct 'cat2 group with filter clause. - * Because there is at least one distinct group with filter clause (e.g. the distinct 'cat2 + * Because there is at least one group with filter clause (e.g. the distinct 'cat2 * group with filter clause), then will project the data. * 2. Avoid projections that may output the same attributes. There are three aggregation groups * in this query: - * i. the non-distinct group; + * i. the non-distinct 'cat1 group; * ii. the distinct 'cat1 group; * iii. the distinct 'cat1 group with filter clause. - * The attributes referenced by different distinct aggregate expressions are likely to overlap, + * The attributes referenced by different aggregate expressions are likely to overlap, * and if no additional processing is performed, data loss will occur. If we directly output - * the attributes of the aggregate expression, we may get two attributes 'cat1. To prevent - * this, we generate new attributes (e.g. '_gen_distinct_1) and replace the original ones. + * the attributes of the aggregate expression, we may get three attributes 'cat1. To prevent + * this, we generate new attributes (e.g. '_gen_attr_1) and replace the original ones. * * Why we need the first phase? guaranteed to compute filter clauses in the first aggregate * locally. 
- * Note: after generate new attributes, the aggregate may have at least two distinct aggregates, + * Note: after generate new attributes, the aggregate may have at least two distinct groups, * so we need the second phase too. * * In the second phase, rewrite a query with two or more distinct groups: @@ -262,45 +262,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => - val (aggregate, projected) = projectFiltersInDistinctAggregates(a) + val (aggregate, projected) = projectFiltersInAggregates(a) rewriteDistinctAggregates(aggregate, projected) } - private def projectFiltersInDistinctAggregates(a: Aggregate): (Aggregate, Boolean) = { + private def projectFiltersInAggregates(a: Aggregate): (Aggregate, Boolean) = { val aggExpressions = collectAggregateExprs(a) - val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct) - if (distinctAggExpressions.exists(_.filter.isDefined)) { - // Constructs pairs between old and new expressions for regular aggregates. Because we - // will construct a new `Aggregate` and the children of the distinct aggregates will be - // changed to generated ones, we need to create new references to avoid collisions between - // distinct and regular aggregate children. - val regularAggExprs = regularAggExpressions.filter(_.children.exists(!_.foldable)) - val regularFunChildren = regularAggExprs - .flatMap(_.aggregateFunction.children.filter(!_.foldable)) - val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) - val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct - val regularAggChildrenMap = regularAggChildren.map { - case ne: NamedExpression => ne -> ne - case other => other -> Alias(other, other.toString)() - } - val namedRegularAggChildren = regularAggChildrenMap.map(_._2) - val regularAggChildAttrLookup = regularAggChildrenMap.map { kv => - (kv._1, kv._2.toAttribute) - }.toMap - val regularAggPairs = regularAggExprs.map { - case ae @ AggregateExpression(af, _, _, filter, _) => - val newChildren = af.children.map(c => regularAggChildAttrLookup.getOrElse(c, c)) - val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] - val filterOpt = filter.map(_.transform { - case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) - }) - val aggExpr = ae.copy(aggregateFunction = raf, filter = filterOpt) - (ae, aggExpr) - } - - // Constructs pairs between old and new expressions for distinct aggregates, too. - val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable)) - val (projections, distinctAggPairs) = distinctAggExprs.map { + if (aggExpressions.exists(_.filter.isDefined)) { + // Constructs pairs between old and new expressions for aggregates. + val aggExprs = aggExpressions.filter(e => e.children.exists(!_.foldable)) + val (projections, aggPairs) = aggExprs.map { case ae @ AggregateExpression(af, _, _, filter, _) => // First, In order to reduce costs, it is better to handle the filter clause locally. // e.g. COUNT (DISTINCT a) FILTER (WHERE id > 1), evaluate expression @@ -308,18 +279,18 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Second, If at least two DISTINCT aggregate expression which may references the // same attributes. We need to construct the generated attributes so as the output not // lost. e.g. 
SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id > 1) will output - // attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead of two 'a. + // attribute '_gen_attr-1 and attribute '_gen_attr-2 instead of two 'a. // Note: The illusionary mechanism may result in at least two distinct groups, so we - // still need to call `rewrite`. + // still need to call `rewriteDistinctAggregates`. val unfoldableChildren = af.children.filter(!_.foldable) // Expand projection val projectionMap = unfoldableChildren.map { case e if filter.isDefined => val ife = If(filter.get, e, nullify(e)) - e -> Alias(ife, s"_gen_distinct_${NamedExpression.newExprId.id}")() + e -> Alias(ife, s"_gen_attr_${NamedExpression.newExprId.id}")() // For convenience and unification, we always alias the distinct column, even if // there is no filter. - case e => e -> Alias(e, s"_gen_distinct_${NamedExpression.newExprId.id}")() + case e => e -> Alias(e, s"_gen_attr_${NamedExpression.newExprId.id}")() } val projection = projectionMap.map(_._2) val exprAttrs = projectionMap.map { kv => @@ -336,12 +307,11 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } - val rewriteAggProjection = - namedGroupingExpressions ++ namedRegularAggChildren ++ projections.flatten + val rewriteAggProjection = namedGroupingExpressions ++ projections.flatten // Construct the project operator. val project = Project(rewriteAggProjection, a.child) val groupByAttrs = namedGroupingExpressions.map(_.toAttribute) - val rewriteAggExprLookup = (distinctAggPairs ++ regularAggPairs).toMap + val rewriteAggExprLookup = aggPairs.toMap val patchedAggExpressions = a.aggregateExpressions.map { e => e.transformDown { case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) @@ -425,10 +395,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // only expand unfoldable children val regularAggExprs = aggExpressions .filter(e => !e.isDistinct && e.children.exists(!_.foldable)) - val regularAggFunChildren = regularAggExprs + val regularAggChildren = regularAggExprs .flatMap(_.aggregateFunction.children.filter(!_.foldable)) - val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) - val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct + .distinct val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) // Setup aggregates for 'regular' aggregate expressions. @@ -437,12 +406,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val regularAggOperatorMap = regularAggExprs.map { e => // Perform the actual aggregation in the initial aggregate. val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get) - // We changed the attributes in the [[Expand]] output using expressionAttributePair. - // So we need to replace the attributes in FILTER expression with new ones. - val filterOpt = e.filter.map(_.transform { - case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) - }) - val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)() + val operator = Alias(e.copy(aggregateFunction = af), e.sql)() // Select the result of the first aggregate in the last aggregate. 
val result = AggregateExpression( @@ -484,7 +448,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } val (projections, expandChild) = if (projected) { - // If `projectFiltersInDistinctAggregates` inserts Project as child of Aggregate and + // If `projectFiltersInAggregates` inserts Project as child of Aggregate and // `rewriteDistinctAggregates` will insert Expand here, merge Project with the Expand. val projectAttributeExpressionMap = a.child.asInstanceOf[Project].projectList.map { case ne: NamedExpression => ne.name -> ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 527a9eac9948e..d03de1507fbbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -157,44 +157,19 @@ abstract class AggregationIterator( inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit = { val joinedRow = new JoinedRow if (expressions.nonEmpty) { - val mergeExpressions = - functions.zip(expressions.map(ae => (ae.mode, ae.isDistinct, ae.filter))).flatMap { - case (ae: DeclarativeAggregate, (mode, isDistinct, filter)) => - mode match { - case Partial | Complete => - if (filter.isDefined) { - ae.updateExpressions.zip(ae.aggBufferAttributes).map { - case (updateExpr, attr) => If(filter.get, updateExpr, attr) - } - } else { - ae.updateExpressions - } - case PartialMerge | Final => ae.mergeExpressions - } - case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - // Initialize predicates for aggregate functions if necessary - val predicateOptions = expressions.map { - case AggregateExpression(_, mode, _, Some(filter), _) => - mode match { - case Partial | Complete => - val predicate = Predicate.create(filter, inputAttributes) - predicate.initialize(partIndex) - Some(predicate) - case _ => None + val mergeExpressions = functions.zip(expressions).flatMap { + case (ae: DeclarativeAggregate, expression) => + expression.mode match { + case Partial | Complete => ae.updateExpressions + case PartialMerge | Final => ae.mergeExpressions } - case _ => None + case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val updateFunctions = functions.zipWithIndex.collect { case (ae: ImperativeAggregate, i) => expressions(i).mode match { case Partial | Complete => - if (predicateOptions(i).isDefined) { - (buffer: InternalRow, row: InternalRow) => - if (predicateOptions(i).get.eval(row)) { ae.update(buffer, row) } - } else { - (buffer: InternalRow, row: InternalRow) => ae.update(buffer, row) - } + (buffer: InternalRow, row: InternalRow) => ae.update(buffer, row) case PartialMerge | Final => (buffer: InternalRow, row: InternalRow) => ae.merge(buffer, row) } From d531864c7da4244674377369854ab2f02b5acfa2 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Mon, 13 Jul 2020 16:21:48 +0800 Subject: [PATCH 19/21] Update comments. 
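
The comment touched below describes the local-filter trick. In DataFrame terms
the same idea reads as the following illustrative equivalence (not the rule's
implementation, and again reusing the hypothetical `data` view):

    import org.apache.spark.sql.functions.{col, countDistinct, when}

    // COUNT(DISTINCT cat1) FILTER (WHERE id > 1) behaves like counting the
    // distinct values of when(id > 1, cat1): rows failing the predicate map
    // to null, and null never contributes to a distinct count.
    val equivalent = spark.table("data")
      .groupBy(col("key"))
      .agg(countDistinct(when(col("id") > 1, col("cat1"))).as("cat1_cnt"))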
---
 .../sql/catalyst/optimizer/RewriteDistinctAggregates.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 8533ca47ebec6..d88cdecb7d22f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -288,7 +288,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
           case e if filter.isDefined =>
             val ife = If(filter.get, e, nullify(e))
             e -> Alias(ife, s"_gen_attr_${NamedExpression.newExprId.id}")()
-          // For convenience and unification, we always alias the distinct column, even if
+          // For convenience and unification, we always alias the column, even if
           // there is no filter.
           case e => e -> Alias(e, s"_gen_attr_${NamedExpression.newExprId.id}")()
         }

From 5bbbfd771e7aa1cfa6c41ccced92c8f0cf760a4d Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Mon, 13 Jul 2020 17:25:54 +0800
Subject: [PATCH 20/21] Optimize code.

---
 .../optimizer/RewriteDistinctAggregates.scala | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index d88cdecb7d22f..15fa4f9bc8bc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -250,18 +250,17 @@ import org.apache.spark.sql.types.IntegerType
  */
 object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
-  private def mayNeedtoRewrite(exprs: Seq[Expression]): Boolean = {
-    val distinctAggs = exprs.flatMap { _.collect {
-      case ae: AggregateExpression if ae.isDistinct => ae
-    }}
-    // We need at least two distinct aggregates or a single distinct aggregate with a filter for
+  private def mayNeedtoRewrite(a: Aggregate): Boolean = {
+    val aggExpressions = collectAggregateExprs(a)
+    val distinctAggs = aggExpressions.filter(_.isDistinct)
+    // We need at least two distinct aggregates, or at least one aggregate with a filter clause, for
     // this rule because aggregation strategy can handle a single distinct group without a filter.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
-    distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
+    distinctAggs.size > 1 || aggExpressions.exists(_.filter.isDefined)
  }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) =>
+    case a: Aggregate if mayNeedtoRewrite(a) =>
       val (aggregate, projected) = projectFiltersInAggregates(a)
       rewriteDistinctAggregates(aggregate, projected)
   }

From 20ad143c620ef75e8d446f8f1e595992a1959b4a Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Tue, 14 Jul 2020 10:20:33 +0800
Subject: [PATCH 21/21] Update comments.
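
Besides rewording the docs, the change below keeps the original grouping
expressions on the rewritten Aggregate and only projects the attributes they
reference. Grouping on an expression rather than a bare column exercises that
path; a sketch, reusing the hypothetical `data` view:

    spark.sql(
      """SELECT id % 2 AS parity,
        |       COUNT(DISTINCT cat1) FILTER (WHERE id > 1) AS cat1_cnt
        |FROM data
        |GROUP BY id % 2""".stripMargin)
      .explain(true)  // `id % 2` stays in the Aggregate; only 'id feeds the Project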
---
 .../optimizer/RewriteDistinctAggregates.scala | 27 ++++++++++---------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 15fa4f9bc8bc3..ba1c1dae69d89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -153,7 +153,8 @@ import org.apache.spark.sql.types.IntegerType
 *     LocalTableScan [...]
 * }}}
 *
- * Fourth example: single distinct aggregate function with filter clauses (in sql):
+ * Fourth example: at least two distinct aggregate functions, one of them with a
+ * filter clause (in sql):
 * {{{
 *   SELECT
 *     COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
@@ -201,15 +202,15 @@ import org.apache.spark.sql.types.IntegerType
 * In the first phase, if the aggregate query exists filter clauses, project the output of
 * the child of the aggregate query:
 *  1. Project the data. There are three aggregation groups in this query:
- *      i. the non-distinct group;
- *      ii. the distinct 'cat1 group;
+ *      i. the non-distinct group without a filter clause;
+ *      ii. the distinct 'cat1 group without a filter clause;
 *      iii. the distinct 'cat2 group with filter clause.
- *     Because there is at least one group with filter clause (e.g. the distinct 'cat2
- *     group with filter clause), then will project the data.
+ *     When at least one aggregate function has a filter clause, we add a Project
+ *     node over the input plan.
 *  2. Avoid projections that may output the same attributes. There are three aggregation groups
 *     in this query:
- *      i. the non-distinct 'cat1 group;
- *      ii. the distinct 'cat1 group;
+ *      i. the non-distinct 'cat1 group without a filter clause;
+ *      ii. the distinct 'cat1 group without a filter clause;
 *      iii. the distinct 'cat1 group with filter clause.
 *     The attributes referenced by different aggregate expressions are likely to overlap,
 *     and if no additional processing is performed, data loss will occur. If we directly output
@@ -302,21 +303,21 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         (projection, (ae, aggExpr))
       }.unzip
       // Construct the aggregate input projection.
-      val namedGroupingExpressions = a.groupingExpressions.map {
-        case ne: NamedExpression => ne
-        case other => Alias(other, other.toString)()
+      val namedGroupingProjection = a.groupingExpressions.flatMap { e =>
+        e.collect {
+          case ar: AttributeReference => ar
+        }
       }
-      val rewriteAggProjection = namedGroupingExpressions ++ projections.flatten
+      val rewriteAggProjection = namedGroupingProjection ++ projections.flatten
       // Construct the project operator.
       val project = Project(rewriteAggProjection, a.child)
-      val groupByAttrs = namedGroupingExpressions.map(_.toAttribute)
       val rewriteAggExprLookup = aggPairs.toMap
       val patchedAggExpressions = a.aggregateExpressions.map { e =>
         e.transformDown {
           case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae)
         }.asInstanceOf[NamedExpression]
       }
-      (Aggregate(groupByAttrs, patchedAggExpressions, project), true)
+      (Aggregate(a.groupingExpressions, patchedAggExpressions, project), true)
     } else {
       (a, false)
     }
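
As an end-to-end check of the whole series, the new `group-by-filter.sql`
cases can be reproduced interactively. The schema below is an assumption
modeled on the `emp` test table, not its exact DDL:

    import java.sql.Date
    import spark.implicits._

    case class Emp(id: Int, emp_name: String, hiredate: Date, salary: Double, dept_id: Int)

    Seq(
      Emp(100, "emp 1", Date.valueOf("2005-01-01"), 100.0, 10),
      Emp(200, "emp 2", Date.valueOf("2003-01-01"), 200.0, 10),
      Emp(300, "emp 3", Date.valueOf("2002-01-01"), 300.0, 20)
    ).toDF().createOrReplaceTempView("emp")

    spark.sql(
      """SELECT dept_id,
        |       COUNT(DISTINCT emp_name, hiredate) FILTER (WHERE id > 200) AS cnt,
        |       SUM(salary) AS total
        |FROM emp
        |GROUP BY dept_id""".stripMargin).show()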