From f22cdd949e54d6715c95934e314982f1280b195d Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Tue, 3 Jun 2025 15:36:48 +0800 Subject: [PATCH] [fix](nereids) fix sum0 cannot pass multi distinct check (#51234) ### What problem does this PR solve? Related PR: #32541 Problem Summary: Before this pr, execute below sql will report error: sum0(DISTINCT c#2) can't support multi distinct. This pr change the check, and the sql can be executed. sql is: select sum0(distinct b),sum(distinct c) from test_sum0_multi_distinct_with_group_by group by a --- .../nereids/rules/rewrite/CheckMultiDistinct.java | 12 ++---------- .../data/nereids_syntax_p0/analyze_agg.out | 12 ++++++++++++ .../suites/nereids_syntax_p0/analyze_agg.groovy | 11 +++++++++++ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java index dd76457c41181f..d9c953444ba090 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java @@ -22,15 +22,10 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import com.google.common.collect.ImmutableSet; - /** * If there are multiple distinct aggregate functions that cannot * be transformed into multi_distinct, an error is reported. @@ -41,9 +36,6 @@ * - group_concat -> MULTI_DISTINCT_GROUP_CONCAT */ public class CheckMultiDistinct extends OneRewriteRuleFactory { - private final ImmutableSet> supportedFunctions = - ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); - @Override public Rule build() { return logicalAggregate().then(agg -> checkDistinct(agg)).toRule(RuleType.CHECK_ANALYSIS); @@ -53,7 +45,7 @@ private LogicalAggregate checkDistinct(LogicalAggregate aggregat if (aggregate.getDistinctArguments().size() > 1) { for (AggregateFunction func : aggregate.getAggregateFunctions()) { - if (func.isDistinct() && !supportedFunctions.contains(func.getClass())) { + if (func.isDistinct() && !(func instanceof SupportMultiDistinct)) { throw new AnalysisException(func.toString() + " can't support multi distinct."); } } diff --git a/regression-test/data/nereids_syntax_p0/analyze_agg.out b/regression-test/data/nereids_syntax_p0/analyze_agg.out index 9c9c4c6c8a2be1..8316c4aefe20c8 100644 --- a/regression-test/data/nereids_syntax_p0/analyze_agg.out +++ b/regression-test/data/nereids_syntax_p0/analyze_agg.out @@ -1,3 +1,15 @@ -- This file is automatically generated. You should know what you did if you want to edit this -- !sql -- +-- !test_sum0 -- +0 0 +0 3 +0 5 +0 7 +5 21 + +-- !test_sum0_all_null -- +0 3 +0 5 +0 7 + diff --git a/regression-test/suites/nereids_syntax_p0/analyze_agg.groovy b/regression-test/suites/nereids_syntax_p0/analyze_agg.groovy index 9a79df6bad524b..cf93cad471ca4b 100644 --- a/regression-test/suites/nereids_syntax_p0/analyze_agg.groovy +++ b/regression-test/suites/nereids_syntax_p0/analyze_agg.groovy @@ -88,4 +88,15 @@ suite("analyze_agg") { 1, x """ + + sql "drop table if exists test_sum0_multi_distinct_with_group_by" + sql "create table test_sum0_multi_distinct_with_group_by (a int, b int, c int) distributed by hash(a) properties('replication_num'='1');" + sql """ + INSERT INTO test_sum0_multi_distinct_with_group_by VALUES + (1, NULL, 3), (2, NULL, 5), (3, NULL, 7), + (4,5,6),(4,5,7),(4,5,8), + (5,0,0),(5,0,0),(5,0,0); + """ + qt_test_sum0 "select sum0(distinct b),sum(distinct c) from test_sum0_multi_distinct_with_group_by group by a order by 1,2" + qt_test_sum0_all_null "select sum0(distinct b),sum(distinct c) from test_sum0_multi_distinct_with_group_by where a in (1,2,3) group by a order by 1,2" } \ No newline at end of file