From c15292d4cf0d25c52078fe9a3c2c2dc202c6c2ce Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Tue, 11 Jun 2024 21:37:48 +0800 Subject: [PATCH 1/2] [Fix](nereids) fix merge aggregate rule, rules should not have mutable members --- .../doris/nereids/rules/rewrite/MergeAggregate.java | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 23fbd9786568eb..bd3f1e4a5a886d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -52,7 +52,6 @@ public class MergeAggregate implements RewriteRuleFactory { private static final ImmutableSet ALLOW_MERGE_AGGREGATE_FUNCTIONS = ImmutableSet.of("min", "max", "sum", "any_value"); - private Map innerAggExprIdToAggFunc = new HashMap<>(); @Override public List buildRules() { @@ -75,7 +74,10 @@ public List buildRules() { */ private Plan mergeTwoAggregate(LogicalAggregate> outerAgg) { LogicalAggregate innerAgg = outerAgg.child(); - + Map innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() + .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) + .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), + (existValue, newValue) -> existValue)); List newOutputExpressions = outerAgg.getOutputExpressions().stream() .map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc)) .collect(Collectors.toList()); @@ -97,6 +99,10 @@ private Plan mergeAggProjectAgg(LogicalAggregate outputExpressions = outerAgg.getOutputExpressions(); List replacedOutputExpressions = PlanUtils.replaceExpressionByProjections( project.getProjects(), (List) outputExpressions); + Map innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() + .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) + .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), + (existValue, newValue) -> existValue)); // rewrite agg function. e.g. max(max) List replacedAggFunc = replacedOutputExpressions.stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) @@ -152,7 +158,7 @@ private NamedExpression rewriteAggregateFunction(NamedExpression e, private boolean commonCheck(LogicalAggregate outerAgg, LogicalAggregate innerAgg, boolean sameGroupBy, Optional projectOptional) { - innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() + Map innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), (existValue, newValue) -> existValue)); From fc2a170d015f77d49d0ac2867e11c68d65fe4605 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Wed, 12 Jun 2024 17:52:51 +0800 Subject: [PATCH 2/2] [Fix](nereids) fix merge aggregate rule, rules should not have mutable members --- .../nereids/rules/rewrite/MergeAggregate.java | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index bd3f1e4a5a886d..4b9e745ee201f4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -74,10 +74,7 @@ public List buildRules() { */ private Plan mergeTwoAggregate(LogicalAggregate> outerAgg) { LogicalAggregate innerAgg = outerAgg.child(); - Map innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() - .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) - .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), - (existValue, newValue) -> existValue)); + Map innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg); List newOutputExpressions = outerAgg.getOutputExpressions().stream() .map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc)) .collect(Collectors.toList()); @@ -99,10 +96,7 @@ private Plan mergeAggProjectAgg(LogicalAggregate outputExpressions = outerAgg.getOutputExpressions(); List replacedOutputExpressions = PlanUtils.replaceExpressionByProjections( project.getProjects(), (List) outputExpressions); - Map innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() - .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) - .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), - (existValue, newValue) -> existValue)); + Map innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg); // rewrite agg function. e.g. max(max) List replacedAggFunc = replacedOutputExpressions.stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) @@ -158,10 +152,7 @@ private NamedExpression rewriteAggregateFunction(NamedExpression e, private boolean commonCheck(LogicalAggregate outerAgg, LogicalAggregate innerAgg, boolean sameGroupBy, Optional projectOptional) { - Map innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() - .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) - .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), - (existValue, newValue) -> existValue)); + Map innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg); Set aggregateFunctions = outerAgg.getAggregateFunctions(); List replacedAggFunctions = projectOptional.map(project -> (List) PlanUtils.replaceExpressionByProjections( @@ -231,4 +222,11 @@ private boolean canMergeAggregateWithProject(LogicalAggregate getInnerAggExprIdToAggFuncMap(LogicalAggregate innerAgg) { + return innerAgg.getOutputExpressions().stream() + .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) + .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), + (existValue, newValue) -> existValue)); + } }