diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 169d5a901a7ce0..2d39852dd1861a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -82,6 +82,12 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { public Rule build() { return RuleType.NORMALIZE_REPEAT.build( logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> { + if (repeat.getGroupingSets().size() == 1 + && ExpressionUtils.collect(repeat.getOutputExpressions(), + GroupingScalarFunction.class::isInstance).isEmpty()) { + return new LogicalAggregate<>(repeat.getGroupByExpressions(), + repeat.getOutputExpressions(), repeat.child()); + } checkRepeatLegality(repeat); repeat = removeDuplicateColumns(repeat); // add virtual slot, LogicalAggregate and LogicalProject for normalize diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java index 3fc2fec9a650b8..556f5279412121 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; @@ -41,7 +42,7 @@ public void testKeepNullableAfterNormalizeRepeat() { Slot name = scan1.getOutput().get(1); Alias alias = new Alias(new Sum(name), "sum(name)"); Plan plan = new LogicalRepeat<>( - ImmutableList.of(ImmutableList.of(id)), + ImmutableList.of(ImmutableList.of(id), ImmutableList.of(name)), ImmutableList.of(idNotNull, alias), scan1 ); @@ -51,4 +52,40 @@ public void testKeepNullableAfterNormalizeRepeat() { logicalRepeat().when(repeat -> repeat.getOutputExpressions().get(0).nullable()) ); } + + @Test + public void testEliminateRepeat() { + Slot id = scan1.getOutput().get(0); + Slot idNotNull = id.withNullable(true); + Slot name = scan1.getOutput().get(1); + Alias alias = new Alias(new Sum(name), "sum(name)"); + Plan plan = new LogicalRepeat<>( + ImmutableList.of(ImmutableList.of(id)), + ImmutableList.of(idNotNull, alias), + scan1 + ); + PlanChecker.from(MemoTestUtils.createCascadesContext(plan)) + .applyTopDown(new NormalizeRepeat()) + .matchesFromRoot( + logicalAggregate(logicalOlapScan()) + ); + } + + @Test + public void testNoEliminateRepeat() { + Slot id = scan1.getOutput().get(0); + Slot idNotNull = id.withNullable(true); + Slot name = scan1.getOutput().get(1); + Alias alias = new Alias(new GroupingId(name), "grouping_id(name)"); + Plan plan = new LogicalRepeat<>( + ImmutableList.of(ImmutableList.of(id)), + ImmutableList.of(idNotNull, alias), + scan1 + ); + PlanChecker.from(MemoTestUtils.createCascadesContext(plan)) + .applyTopDown(new NormalizeRepeat()) + .matchesFromRoot( + logicalAggregate(logicalRepeat(logicalOlapScan())) + ); + } } diff --git a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy index 8310685c9c6c40..93821452f2f52f 100644 --- a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy +++ b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy @@ -39,4 +39,19 @@ suite("grouping_normalize_test"){ SELECT ROUND( SUM(pk + 1) - 3) col_alias1, MAX( DISTINCT col_int_undef_signed - 5) AS col_alias2, pk + 1 AS col_alias3 FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed,col_int_undef_signed2,pk),()) order by 1,2,3; """ + + explain { + sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2));") + notContains("VREPEAT_NODE") + } + + explain { + sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk), grouping_id(col_int_undef_signed2) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2),());") + contains("VREPEAT_NODE") + } + + explain { + sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2));") + notContains("VREPEAT_NODE") + } } \ No newline at end of file