diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java index dd76457c41181f..d9c953444ba090 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMultiDistinct.java @@ -22,15 +22,10 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import com.google.common.collect.ImmutableSet; - /** * If there are multiple distinct aggregate functions that cannot * be transformed into multi_distinct, an error is reported. @@ -41,9 +36,6 @@ * - group_concat -> MULTI_DISTINCT_GROUP_CONCAT */ public class CheckMultiDistinct extends OneRewriteRuleFactory { - private final ImmutableSet> supportedFunctions = - ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); - @Override public Rule build() { return logicalAggregate().then(agg -> checkDistinct(agg)).toRule(RuleType.CHECK_ANALYSIS); @@ -53,7 +45,7 @@ private LogicalAggregate checkDistinct(LogicalAggregate aggregat if (aggregate.getDistinctArguments().size() > 1) { for (AggregateFunction func : aggregate.getAggregateFunctions()) { - if (func.isDistinct() && !supportedFunctions.contains(func.getClass())) { + if (func.isDistinct() && !(func instanceof SupportMultiDistinct)) { throw new AnalysisException(func.toString() + " can't support multi distinct."); } } diff --git a/regression-test/data/nereids_syntax_p0/analyze_agg.out b/regression-test/data/nereids_syntax_p0/analyze_agg.out index 9c9c4c6c8a2be1..8316c4aefe20c8 100644 --- a/regression-test/data/nereids_syntax_p0/analyze_agg.out +++ b/regression-test/data/nereids_syntax_p0/analyze_agg.out @@ -1,3 +1,15 @@ -- This file is automatically generated. You should know what you did if you want to edit this -- !sql -- +-- !test_sum0 -- +0 0 +0 3 +0 5 +0 7 +5 21 + +-- !test_sum0_all_null -- +0 3 +0 5 +0 7 + diff --git a/regression-test/suites/nereids_syntax_p0/analyze_agg.groovy b/regression-test/suites/nereids_syntax_p0/analyze_agg.groovy index 9a79df6bad524b..cf93cad471ca4b 100644 --- a/regression-test/suites/nereids_syntax_p0/analyze_agg.groovy +++ b/regression-test/suites/nereids_syntax_p0/analyze_agg.groovy @@ -88,4 +88,15 @@ suite("analyze_agg") { 1, x """ + + sql "drop table if exists test_sum0_multi_distinct_with_group_by" + sql "create table test_sum0_multi_distinct_with_group_by (a int, b int, c int) distributed by hash(a) properties('replication_num'='1');" + sql """ + INSERT INTO test_sum0_multi_distinct_with_group_by VALUES + (1, NULL, 3), (2, NULL, 5), (3, NULL, 7), + (4,5,6),(4,5,7),(4,5,8), + (5,0,0),(5,0,0),(5,0,0); + """ + qt_test_sum0 "select sum0(distinct b),sum(distinct c) from test_sum0_multi_distinct_with_group_by group by a order by 1,2" + qt_test_sum0_all_null "select sum0(distinct b),sum(distinct c) from test_sum0_multi_distinct_with_group_by where a in (1,2,3) group by a order by 1,2" } \ No newline at end of file