diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 2361c2763727bf..3edb35c26932c0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -299,7 +299,9 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("Eliminate GroupBy", topDown(new EliminateGroupBy(), - new MergeAggregate()) + new MergeAggregate(), + // need to adjust min/max/sum nullable attribute after merge aggregate + new AdjustAggregateNullableForEmptySet()) ), topic("Eager aggregation", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 9a0b9f8b5e0353..a2c23dd9b412b0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -34,10 +34,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -87,15 +89,14 @@ private Plan mergeTwoAggregate(LogicalAggregate> outerAgg private Plan mergeAggProjectAgg(LogicalAggregate>> outerAgg) { LogicalProject> project = outerAgg.child(); LogicalAggregate innerAgg = project.child(); - + List outputExpressions = outerAgg.getOutputExpressions(); + List replacedOutputExpressions = PlanUtils.replaceExpressionByProjections( + project.getProjects(), (List) outputExpressions); // rewrite agg function. e.g. max(max) - List aggFunc = outerAgg.getOutputExpressions().stream() + List replacedAggFunc = replacedOutputExpressions.stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) .map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc)) .collect(Collectors.toList()); - // rewrite agg function directly refer to the slot below the project - List replacedAggFunc = PlanUtils.replaceExpressionByProjections(project.getProjects(), - (List) aggFunc); // replace groupByKeys directly refer to the slot below the project List replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(), outerAgg.getGroupByExpressions()); @@ -138,13 +139,17 @@ private NamedExpression rewriteAggregateFunction(NamedExpression e, } boolean commonCheck(LogicalAggregate outerAgg, LogicalAggregate innerAgg, - boolean sameGroupBy) { + boolean sameGroupBy, Optional projectOptional) { innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), (existValue, newValue) -> existValue)); Set aggregateFunctions = outerAgg.getAggregateFunctions(); - for (AggregateFunction outerFunc : aggregateFunctions) { + List replacedAggFunctions = projectOptional.map(project -> + (List) PlanUtils.replaceExpressionByProjections( + projectOptional.get().getProjects(), new ArrayList<>(aggregateFunctions))) + .orElse(new ArrayList<>(aggregateFunctions)); + for (AggregateFunction outerFunc : replacedAggFunctions) { if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) { return false; } @@ -188,7 +193,7 @@ private boolean canMergeAggregateWithoutProject(LogicalAggregate>> outerAgg) { @@ -206,6 +211,6 @@ private boolean canMergeAggregateWithProject(LogicalAggregate