From 2b10d1a4be19ab4883fbd939ea708d1cd723ef5f Mon Sep 17 00:00:00 2001 From: xiejiann Date: Tue, 28 May 2024 22:46:57 +0800 Subject: [PATCH] don't rewrite sum once --- .../rules/rewrite/SumLiteralRewrite.java | 25 +++++++++++++-- .../rules/rewrite/SumLiteralRewriteTest.java | 31 +++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java index c99071a714e7d5..dcc64ce2c1d9cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java @@ -44,6 +44,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; import java.util.Set; @@ -64,13 +65,33 @@ public Rule build() { } sumLiteralMap.put(pel.first, pel.second); } - if (sumLiteralMap.isEmpty()) { + Map> validSumLiteralMap = + removeOneSumLiteral(sumLiteralMap); + if (validSumLiteralMap.isEmpty()) { return null; } - return rewriteSumLiteral(agg, sumLiteralMap); + return rewriteSumLiteral(agg, validSumLiteralMap); }).toRule(RuleType.SUM_LITERAL_REWRITE); } + // when there only one sum literal like select count(id1 + 1), count(id2 + 1) from t, we don't rewrite them. + private Map> removeOneSumLiteral( + Map> sumLiteralMap) { + Map countSum = new HashMap<>(); + for (Entry> e : sumLiteralMap.entrySet()) { + Expression expr = e.getValue().first.expr; + countSum.merge(expr, 1, Integer::sum); + } + Map> validSumLiteralMap = new HashMap<>(); + for (Entry> e : sumLiteralMap.entrySet()) { + Expression expr = e.getValue().first.expr; + if (countSum.get(expr) > 1) { + validSumLiteralMap.put(e.getKey(), e.getValue()); + } + } + return validSumLiteralMap; + } + private Plan rewriteSumLiteral( LogicalAggregate agg, Map> sumLiteralMap) { Set newAggOutput = new HashSet<>(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java index cb2cc77627e16e..19ea7b864fb9b1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java @@ -112,4 +112,35 @@ void testSumDistinct() { .printlnTree() .matches(logicalAggregate().when(p -> p.getOutputs().size() == 4)); } + + @Test + void testSumOnce() { + Slot slot1 = scan1.getOutput().get(0); + Alias add1 = new Alias(new Sum(false, true, new Add(slot1, Literal.of(1)))); + LogicalAggregate agg = new LogicalAggregate<>( + ImmutableList.of(scan1.getOutput().get(0)), ImmutableList.of(add1), scan1); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build())) + .printlnTree() + .matches(logicalAggregate().when(p -> p.getOutputs().size() == 1)); + + Slot slot2 = new Alias(scan1.getOutput().get(0)).toSlot(); + Alias add2 = new Alias(new Sum(false, true, new Add(slot2, Literal.of(2)))); + agg = new LogicalAggregate<>( + ImmutableList.of(scan1.getOutput().get(0)), ImmutableList.of(add1, add2), scan1); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build())) + .printlnTree() + .matches(logicalAggregate().when(p -> p.getOutputs().size() == 2)); + + Alias add3 = new Alias(new Sum(false, true, new Add(slot1, Literal.of(3)))); + Alias add4 = new Alias(new Sum(false, true, new Add(slot1, Literal.of(4)))); + agg = new LogicalAggregate<>( + ImmutableList.of(scan1.getOutput().get(0)), ImmutableList.of(add1, add2, add3, add4), scan1); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build())) + .printlnTree() + .matches(logicalAggregate().when(p -> p.getOutputs().size() == 3)); + + } }