From a989db63c93dec8c80bfda7f445b007cc28dc55d Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Thu, 23 May 2024 22:16:12 +0800 Subject: [PATCH 1/4] [Fix](nereids) fix merge aggregate bug --- .../nereids/rules/rewrite/MergeAggregate.java | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index a2c23dd9b412b0..1b471be9591430 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -31,6 +31,7 @@ import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; +import com.clearspring.analytics.util.Lists; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -86,6 +87,15 @@ private Plan mergeTwoAggregate(LogicalAggregate> outerAgg * LogicalProject (projects = [a as col2, sum(col1) as sum(col1)] * +--LogicalAggregate (outputExpression = [a, sum(c) as sum(col1)], groupByKeys = [a]) */ + /** + * before: + * LogicalAggregate (outputExpressions = [col2, col3, sum(col1)], groupByKeys = [col2, col3]) + * +--LogicalProject (projects = [a as col2, a as col3, col1]) + * +--LogicalAggregate (outputExpressions = [a, b, sum(c) as col1], groupByKeys = [a,b]) + * after: + * LogicalProject (projects = [a as col2, a as col3, sum(col1) as sum(col1)] + * +--LogicalAggregate (outputExpression = [a, sum(c) as sum(col1)], groupByKeys = [a, a]) + */ private Plan mergeAggProjectAgg(LogicalAggregate>> outerAgg) { LogicalProject> project = outerAgg.child(); LogicalAggregate innerAgg = project.child(); @@ -110,8 +120,21 @@ private Plan mergeAggProjectAgg(LogicalAggregate childToAlias = project.getProjects().stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof SlotReference)) - .collect(Collectors.toMap(alias -> (SlotReference) alias.child(0), alias -> (Alias) alias)); - List projectGroupBy = ExpressionUtils.replace(replacedGroupBy, childToAlias); + .collect(Collectors.toMap(alias -> (SlotReference) alias.child(0), alias -> (Alias) alias, + (existValue, newValue) -> existValue)); + + Map exprIdToNameExpressionMap = new HashMap<>(); + for (NamedExpression pro : project.getProjects()) { + exprIdToNameExpressionMap.put(pro.getExprId(), pro); + } + List originOuterAggGroupBy = outerAgg.getGroupByExpressions(); + List projectGroupBy = Lists.newArrayList(); + for (int i = 0; i < replacedGroupBy.size(); i++) { + ExprId exprId = ((NamedExpression) (originOuterAggGroupBy.get(i))).getExprId(); + NamedExpression namedExpression = exprIdToNameExpressionMap.get(exprId); + projectGroupBy.add(namedExpression); + } + // List projectGroupBy = ExpressionUtils.replace(replacedGroupBy, childToAlias); List upperProjects = ImmutableList.builder() .addAll(projectGroupBy.stream().map(namedExpr -> (NamedExpression) namedExpr).iterator()) .addAll(replacedAggFunc.stream().map(expr -> ((NamedExpression) expr).toSlot()).iterator()) From 207919ef6cf494946e5cae0ba322a89be74c417d Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Fri, 24 May 2024 15:28:48 +0800 Subject: [PATCH 2/4] [Fix](nereids) fix merge aggregate bug --- .../nereids/rules/rewrite/MergeAggregate.java | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 1b471be9591430..7d21f37810c967 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -17,8 +17,10 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.annotation.DependsRules; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -31,7 +33,6 @@ import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; -import com.clearspring.analytics.util.Lists; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -45,6 +46,9 @@ import java.util.stream.Collectors; /**MergeAggregate*/ +@DependsRules({ + NormalizeAggregate.class +}) public class MergeAggregate implements RewriteRuleFactory { private static final ImmutableSet ALLOW_MERGE_AGGREGATE_FUNCTIONS = ImmutableSet.of("min", "max", "sum", "any_value"); @@ -87,15 +91,6 @@ private Plan mergeTwoAggregate(LogicalAggregate> outerAgg * LogicalProject (projects = [a as col2, sum(col1) as sum(col1)] * +--LogicalAggregate (outputExpression = [a, sum(c) as sum(col1)], groupByKeys = [a]) */ - /** - * before: - * LogicalAggregate (outputExpressions = [col2, col3, sum(col1)], groupByKeys = [col2, col3]) - * +--LogicalProject (projects = [a as col2, a as col3, col1]) - * +--LogicalAggregate (outputExpressions = [a, b, sum(c) as col1], groupByKeys = [a,b]) - * after: - * LogicalProject (projects = [a as col2, a as col3, sum(col1) as sum(col1)] - * +--LogicalAggregate (outputExpression = [a, sum(c) as sum(col1)], groupByKeys = [a, a]) - */ private Plan mergeAggProjectAgg(LogicalAggregate>> outerAgg) { LogicalProject> project = outerAgg.child(); LogicalAggregate innerAgg = project.child(); @@ -118,18 +113,13 @@ private Plan mergeAggProjectAgg(LogicalAggregate childToAlias = project.getProjects().stream() - .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof SlotReference)) - .collect(Collectors.toMap(alias -> (SlotReference) alias.child(0), alias -> (Alias) alias, - (existValue, newValue) -> existValue)); - Map exprIdToNameExpressionMap = new HashMap<>(); for (NamedExpression pro : project.getProjects()) { exprIdToNameExpressionMap.put(pro.getExprId(), pro); } List originOuterAggGroupBy = outerAgg.getGroupByExpressions(); - List projectGroupBy = Lists.newArrayList(); - for (int i = 0; i < replacedGroupBy.size(); i++) { + List projectGroupBy = new ArrayList<>(); + for (int i = 0; i < originOuterAggGroupBy.size(); i++) { ExprId exprId = ((NamedExpression) (originOuterAggGroupBy.get(i))).getExprId(); NamedExpression namedExpression = exprIdToNameExpressionMap.get(exprId); projectGroupBy.add(namedExpression); From e6aa9dc8eedfaac8a26aba48cc07a79d53acc33f Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Sun, 26 May 2024 21:21:16 +0800 Subject: [PATCH 3/4] [Fix](nereids) fix merge aggregate bug --- .../nereids_rules_p0/merge_aggregate/merge_aggregate.out | 9 +++++++++ .../merge_aggregate/merge_aggregate.groovy | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out index fba17e8d7b9c27..d7103bfed9f686 100644 --- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out +++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out @@ -297,3 +297,12 @@ PhysicalResultSink ------hashAgg[LOCAL] --------PhysicalOlapScan[mal_test2] +-- !agg_project_agg_the_project_has_duplicate_slot_output -- +1 7 7 +2 4 4 +6 \N \N +7 1 1 +8 2 2 +8 5 5 +9 3 3 + diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy index 039f087c9382bc..4a20cf4d68b97c 100644 --- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy @@ -256,4 +256,10 @@ suite("merge_aggregate") { explain shape plan select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; """ + + qt_agg_project_agg_the_project_has_duplicate_slot_output """ + select max(col1), col10, col11 from + (select a,max(b) as col1, count(b) as col4, a as col10, a as col11 + from mal_test1 group by a) t group by col10, col11 order by 1,2,3; + """ } From 3defeb6dd438f56943128e3fcdaab6fcadff3d69 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Mon, 27 May 2024 12:05:13 +0800 Subject: [PATCH 4/4] [Fix](nereids) fix merge aggregate bug --- .../apache/doris/nereids/rules/rewrite/MergeAggregate.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 7d21f37810c967..a8b86f54eb9666 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -119,12 +119,11 @@ private Plan mergeAggProjectAgg(LogicalAggregate originOuterAggGroupBy = outerAgg.getGroupByExpressions(); List projectGroupBy = new ArrayList<>(); - for (int i = 0; i < originOuterAggGroupBy.size(); i++) { - ExprId exprId = ((NamedExpression) (originOuterAggGroupBy.get(i))).getExprId(); + for (Expression expression : originOuterAggGroupBy) { + ExprId exprId = ((NamedExpression) expression).getExprId(); NamedExpression namedExpression = exprIdToNameExpressionMap.get(exprId); projectGroupBy.add(namedExpression); } - // List projectGroupBy = ExpressionUtils.replace(replacedGroupBy, childToAlias); List upperProjects = ImmutableList.builder() .addAll(projectGroupBy.stream().map(namedExpr -> (NamedExpression) namedExpr).iterator()) .addAll(replacedAggFunc.stream().map(expr -> ((NamedExpression) expr).toSlot()).iterator())