diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index a2c23dd9b412b0..a8b86f54eb9666 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -17,8 +17,10 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.annotation.DependsRules; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -44,6 +46,9 @@ import java.util.stream.Collectors; /**MergeAggregate*/ +@DependsRules({ + NormalizeAggregate.class +}) public class MergeAggregate implements RewriteRuleFactory { private static final ImmutableSet ALLOW_MERGE_AGGREGATE_FUNCTIONS = ImmutableSet.of("min", "max", "sum", "any_value"); @@ -108,10 +113,17 @@ private Plan mergeAggProjectAgg(LogicalAggregate childToAlias = project.getProjects().stream() - .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof SlotReference)) - .collect(Collectors.toMap(alias -> (SlotReference) alias.child(0), alias -> (Alias) alias)); - List projectGroupBy = ExpressionUtils.replace(replacedGroupBy, childToAlias); + Map exprIdToNameExpressionMap = new HashMap<>(); + for (NamedExpression pro : project.getProjects()) { + exprIdToNameExpressionMap.put(pro.getExprId(), pro); + } + List originOuterAggGroupBy = outerAgg.getGroupByExpressions(); + List projectGroupBy = new ArrayList<>(); + for (Expression expression : originOuterAggGroupBy) { + ExprId exprId = ((NamedExpression) expression).getExprId(); + NamedExpression namedExpression = exprIdToNameExpressionMap.get(exprId); + projectGroupBy.add(namedExpression); + } List upperProjects = ImmutableList.builder() .addAll(projectGroupBy.stream().map(namedExpr -> (NamedExpression) namedExpr).iterator()) .addAll(replacedAggFunc.stream().map(expr -> ((NamedExpression) expr).toSlot()).iterator()) diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out index fba17e8d7b9c27..d7103bfed9f686 100644 --- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out +++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out @@ -297,3 +297,12 @@ PhysicalResultSink ------hashAgg[LOCAL] --------PhysicalOlapScan[mal_test2] +-- !agg_project_agg_the_project_has_duplicate_slot_output -- +1 7 7 +2 4 4 +6 \N \N +7 1 1 +8 2 2 +8 5 5 +9 3 3 + diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy index 039f087c9382bc..4a20cf4d68b97c 100644 --- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy @@ -256,4 +256,10 @@ suite("merge_aggregate") { explain shape plan select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; """ + + qt_agg_project_agg_the_project_has_duplicate_slot_output """ + select max(col1), col10, col11 from + (select a,max(b) as col1, count(b) as col4, a as col10, a as col11 + from mal_test1 group by a) t group by col10, col11 order by 1,2,3; + """ }