From d0f454423999c1e55a1f0f3de291710773a2c1da Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Fri, 19 Apr 2024 16:34:47 +0800 Subject: [PATCH 1/4] [Fix](nereids) fix rule merge_aggregate when has project --- .../nereids/rules/rewrite/MergeAggregate.java | 14 +++-- .../merge_aggregate/merge_aggregate.out | 26 ++++++++++ .../merge_aggregate/merge_aggregate.groovy | 52 +++++++++++++++++++ 3 files changed, 88 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 9a0b9f8b5e0353..3fac433e722f3d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -34,10 +34,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -138,13 +140,17 @@ private NamedExpression rewriteAggregateFunction(NamedExpression e, } boolean commonCheck(LogicalAggregate outerAgg, LogicalAggregate innerAgg, - boolean sameGroupBy) { + boolean sameGroupBy, Optional projectOptional) { innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), (existValue, newValue) -> existValue)); Set aggregateFunctions = outerAgg.getAggregateFunctions(); - for (AggregateFunction outerFunc : aggregateFunctions) { + List replacedAggFunctions = projectOptional.map(project -> + (List) PlanUtils.replaceExpressionByProjections( + projectOptional.get().getProjects(), new ArrayList<>(aggregateFunctions))) + .orElse(new ArrayList<>(aggregateFunctions)); + for (AggregateFunction outerFunc : replacedAggFunctions) { if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) { return false; } @@ -188,7 +194,7 @@ private boolean canMergeAggregateWithoutProject(LogicalAggregate>> outerAgg) { @@ -206,6 +212,6 @@ private boolean canMergeAggregateWithProject(LogicalAggregate Date: Fri, 19 Apr 2024 18:05:38 +0800 Subject: [PATCH 2/4] [Fix](nereids) fix rule merge_aggregate when has project --- .../nereids/rules/rewrite/MergeAggregate.java | 9 ++++---- .../merge_aggregate/merge_aggregate.out | 19 +++++++++++++-- .../merge_aggregate/merge_aggregate.groovy | 23 +++++++++++++++++-- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 3fac433e722f3d..a2c23dd9b412b0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -89,15 +89,14 @@ private Plan mergeTwoAggregate(LogicalAggregate> outerAgg private Plan mergeAggProjectAgg(LogicalAggregate>> outerAgg) { LogicalProject> project = outerAgg.child(); LogicalAggregate innerAgg = project.child(); - + List outputExpressions = outerAgg.getOutputExpressions(); + List replacedOutputExpressions = PlanUtils.replaceExpressionByProjections( + project.getProjects(), (List) outputExpressions); // rewrite agg function. e.g. max(max) - List aggFunc = outerAgg.getOutputExpressions().stream() + List replacedAggFunc = replacedOutputExpressions.stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) .map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc)) .collect(Collectors.toList()); - // rewrite agg function directly refer to the slot below the project - List replacedAggFunc = PlanUtils.replaceExpressionByProjections(project.getProjects(), - (List) aggFunc); // replace groupByKeys directly refer to the slot below the project List replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(), outerAgg.getGroupByExpressions()); diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out index 8d816232afa714..abaccf0356828f 100644 --- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out +++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out @@ -261,10 +261,10 @@ PhysicalResultSink ----------------PhysicalProject ------------------PhysicalOlapScan[mal_test_merge_agg] --- !test_has_project_distinct_expr_transform -- +-- !test_distinct_expr_transform -- -1 --- !test_has_project_distinct_expr_transform_shape -- +-- !test_distinct_expr_transform_shape -- PhysicalResultSink --hashAgg[GLOBAL] ----PhysicalDistribute[DistributionSpecGather] @@ -272,3 +272,18 @@ PhysicalResultSink --------PhysicalProject ----------PhysicalOlapScan[mal_test_merge_agg] +-- !test_has_project_distinct_expr_transform -- +1 +1 +1 + +-- !test_has_project_distinct_expr_transform -- +PhysicalResultSink +--PhysicalDistribute[DistributionSpecGather] +----PhysicalProject +------hashAgg[GLOBAL] +--------PhysicalDistribute[DistributionSpecHash] +----------hashAgg[LOCAL] +------------PhysicalProject +--------------PhysicalOlapScan[mal_test_merge_agg] + diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy index aeb020b69801fd..2b9d8746e47b16 100644 --- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy @@ -209,7 +209,7 @@ suite("merge_aggregate") { ) t ; """ - qt_test_has_project_distinct_expr_transform """ + qt_test_distinct_expr_transform """ select max(count_col) from ( select k4, @@ -217,7 +217,7 @@ suite("merge_aggregate") { from mal_test_merge_agg group by k4 ) t ; """ - qt_test_has_project_distinct_expr_transform_shape """ + qt_test_distinct_expr_transform_shape """ explain shape plan select max(count_col) from ( @@ -226,4 +226,23 @@ suite("merge_aggregate") { from mal_test_merge_agg group by k4 ) t ; """ + + qt_test_has_project_distinct_expr_transform """ + select sum(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t group by k4; + """ + + qt_test_has_project_distinct_expr_transform """ + explain shape plan + select sum(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t group by k4; + """ } From f95e4248e13feb4aa7a47f05e5b20b69ab22e12a Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Mon, 22 Apr 2024 11:49:06 +0800 Subject: [PATCH 3/4] adjust nullable after merge aggregate --- .../apache/doris/nereids/jobs/executor/Rewriter.java | 3 ++- .../merge_aggregate/merge_aggregate.out | 10 ++++++++++ .../merge_aggregate/merge_aggregate.groovy | 9 +++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 2361c2763727bf..dc93879fd43ca9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -299,7 +299,8 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("Eliminate GroupBy", topDown(new EliminateGroupBy(), - new MergeAggregate()) + new MergeAggregate(), + new AdjustAggregateNullableForEmptySet()) ), topic("Eager aggregation", diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out index abaccf0356828f..fba17e8d7b9c27 100644 --- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out +++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out @@ -287,3 +287,13 @@ PhysicalResultSink ------------PhysicalProject --------------PhysicalOlapScan[mal_test_merge_agg] +-- !test_sum_empty_table -- +\N \N \N + +-- !test_sum_empty_table_shape -- +PhysicalResultSink +--hashAgg[GLOBAL] +----PhysicalDistribute[DistributionSpecGather] +------hashAgg[LOCAL] +--------PhysicalOlapScan[mal_test2] + diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy index 2b9d8746e47b16..46cd4a0a9b78e8 100644 --- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy @@ -245,4 +245,13 @@ suite("merge_aggregate") { from mal_test_merge_agg group by k4 ) t group by k4; """ + + qt_test_sum_empty_table """ + select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; + """ + + qt_test_sum_empty_table_shape """ + explain shape plan + select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; + """ } From 429ee561c05151b369f2b88f8cf192a53d4b287c Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Wed, 24 Apr 2024 16:31:38 +0800 Subject: [PATCH 4/4] add comment --- .../java/org/apache/doris/nereids/jobs/executor/Rewriter.java | 1 + 1 file changed, 1 insertion(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index dc93879fd43ca9..3edb35c26932c0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -300,6 +300,7 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("Eliminate GroupBy", topDown(new EliminateGroupBy(), new MergeAggregate(), + // need to adjust min/max/sum nullable attribute after merge aggregate new AdjustAggregateNullableForEmptySet()) ),