From 181da5cee0a2cbde2792accda1eca78866b71322 Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Mon, 3 Jun 2024 18:50:34 +0800 Subject: [PATCH 1/4] [enhancement](nereids)eliminate repeat node if there is only 1 grouping set and no grouping scalar function --- .../doris/nereids/rules/analysis/NormalizeRepeat.java | 8 ++++++++ .../grouping_sets/grouping_normalize_test.groovy | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 169d5a901a7ce0..f61d42c2bf361e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -116,6 +116,11 @@ private void checkGroupingSetsSize(LogicalRepeat repeat) { } private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { + if (repeat.getGroupingSets().size() == 1 && ExpressionUtils + .collect(repeat.getOutputExpressions(), GroupingScalarFunction.class::isInstance).isEmpty()) { + return new LogicalAggregate<>(repeat.getGroupByExpressions(), + repeat.getOutputExpressions(), Optional.empty(), repeat.child()); + } Set needToSlotsGroupingExpr = collectNeedToSlotGroupingExpr(repeat); NormalizeToSlotContext groupingExprContext = buildContext(repeat, needToSlotsGroupingExpr); Map groupingExprMap = groupingExprContext.getNormalizeToSlotMap(); @@ -316,6 +321,9 @@ private Set getExistsAlias(LogicalRepeat repeat, */ private LogicalAggregate dealSlotAppearBothInAggFuncAndGroupingSets( @NotNull LogicalAggregate aggregate) { + if (!(aggregate.child() instanceof LogicalRepeat)) { + return aggregate; + } LogicalRepeat repeat = (LogicalRepeat) aggregate.child(); Map commonSlotToAliasMap = getCommonSlotToAliasMap(repeat, aggregate); if (commonSlotToAliasMap.isEmpty()) { diff --git 
a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy index 8310685c9c6c40..efd0ef2edf0bd4 100644 --- a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy +++ b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy @@ -39,4 +39,9 @@ suite("grouping_normalize_test"){ SELECT ROUND( SUM(pk + 1) - 3) col_alias1, MAX( DISTINCT col_int_undef_signed - 5) AS col_alias2, pk + 1 AS col_alias3 FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed,col_int_undef_signed2,pk),()) order by 1,2,3; """ + + explain { + sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2));") + notContains("VREPEAT_NODE") + } } \ No newline at end of file From 4454381ff0a19ac736860de915c5a5f375a609d6 Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Mon, 3 Jun 2024 18:55:49 +0800 Subject: [PATCH 2/4] add more regression test cases --- .../grouping_sets/grouping_normalize_test.groovy | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy index efd0ef2edf0bd4..93821452f2f52f 100644 --- a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy +++ b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy @@ -44,4 +44,14 @@ suite("grouping_normalize_test"){ sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2));") notContains("VREPEAT_NODE") } + + explain { + sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk), 
grouping_id(col_int_undef_signed2) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2),());") + contains("VREPEAT_NODE") + } + + explain { + sql("SELECT col_int_undef_signed, col_int_undef_signed2, SUM(pk) FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed, col_int_undef_signed2));") + notContains("VREPEAT_NODE") + } } \ No newline at end of file From fe8f62e5cf0e877095dfe1419765b7ba759cc4f4 Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Mon, 3 Jun 2024 22:14:18 +0800 Subject: [PATCH 3/4] fix FE UT: use two grouping sets in testKeepNullableAfterNormalizeRepeat --- .../doris/nereids/rules/analysis/NormalizeRepeatTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java index 3fc2fec9a650b8..f1d6df69b2d36b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java @@ -41,7 +41,7 @@ public void testKeepNullableAfterNormalizeRepeat() { Slot name = scan1.getOutput().get(1); Alias alias = new Alias(new Sum(name), "sum(name)"); Plan plan = new LogicalRepeat<>( - ImmutableList.of(ImmutableList.of(id)), + ImmutableList.of(ImmutableList.of(id), ImmutableList.of(name)), ImmutableList.of(idNotNull, alias), scan1 ); From a6387e4a5545e05885009b0b0b2949a66d3e464f Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Tue, 4 Jun 2024 14:27:35 +0800 Subject: [PATCH 4/4] move single-grouping-set elimination into build() and add unit tests --- .../rules/analysis/NormalizeRepeat.java | 14 +++---- .../rules/analysis/NormalizeRepeatTest.java | 37 +++++++++++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index f61d42c2bf361e..2d39852dd1861a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -82,6 +82,12 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { public Rule build() { return RuleType.NORMALIZE_REPEAT.build( logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> { + if (repeat.getGroupingSets().size() == 1 + && ExpressionUtils.collect(repeat.getOutputExpressions(), + GroupingScalarFunction.class::isInstance).isEmpty()) { + return new LogicalAggregate<>(repeat.getGroupByExpressions(), + repeat.getOutputExpressions(), repeat.child()); + } checkRepeatLegality(repeat); repeat = removeDuplicateColumns(repeat); // add virtual slot, LogicalAggregate and LogicalProject for normalize @@ -116,11 +122,6 @@ private void checkGroupingSetsSize(LogicalRepeat repeat) { } private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { - if (repeat.getGroupingSets().size() == 1 && ExpressionUtils - .collect(repeat.getOutputExpressions(), GroupingScalarFunction.class::isInstance).isEmpty()) { - return new LogicalAggregate<>(repeat.getGroupByExpressions(), - repeat.getOutputExpressions(), Optional.empty(), repeat.child()); - } Set needToSlotsGroupingExpr = collectNeedToSlotGroupingExpr(repeat); NormalizeToSlotContext groupingExprContext = buildContext(repeat, needToSlotsGroupingExpr); Map groupingExprMap = groupingExprContext.getNormalizeToSlotMap(); @@ -321,9 +322,6 @@ private Set getExistsAlias(LogicalRepeat repeat, */ private LogicalAggregate dealSlotAppearBothInAggFuncAndGroupingSets( @NotNull LogicalAggregate aggregate) { - if (!(aggregate.child() instanceof LogicalRepeat)) { - return aggregate; - } LogicalRepeat repeat = (LogicalRepeat) aggregate.child(); Map commonSlotToAliasMap = 
getCommonSlotToAliasMap(repeat, aggregate); if (commonSlotToAliasMap.isEmpty()) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java index f1d6df69b2d36b..556f5279412121 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; @@ -51,4 +52,40 @@ public void testKeepNullableAfterNormalizeRepeat() { logicalRepeat().when(repeat -> repeat.getOutputExpressions().get(0).nullable()) ); } + + @Test + public void testEliminateRepeat() { + Slot id = scan1.getOutput().get(0); + Slot idNotNull = id.withNullable(true); + Slot name = scan1.getOutput().get(1); + Alias alias = new Alias(new Sum(name), "sum(name)"); + Plan plan = new LogicalRepeat<>( + ImmutableList.of(ImmutableList.of(id)), + ImmutableList.of(idNotNull, alias), + scan1 + ); + PlanChecker.from(MemoTestUtils.createCascadesContext(plan)) + .applyTopDown(new NormalizeRepeat()) + .matchesFromRoot( + logicalAggregate(logicalOlapScan()) + ); + } + + @Test + public void testNoEliminateRepeat() { + Slot id = scan1.getOutput().get(0); + Slot idNotNull = id.withNullable(true); + Slot name = scan1.getOutput().get(1); + Alias alias = new Alias(new GroupingId(name), "grouping_id(name)"); + Plan plan = new LogicalRepeat<>( + ImmutableList.of(ImmutableList.of(id)), + 
ImmutableList.of(idNotNull, alias), + scan1 + ); + PlanChecker.from(MemoTestUtils.createCascadesContext(plan)) + .applyTopDown(new NormalizeRepeat()) + .matchesFromRoot( + logicalAggregate(logicalRepeat(logicalOlapScan())) + ); + } }