Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;

Expand All @@ -64,13 +65,33 @@ public Rule build() {
}
sumLiteralMap.put(pel.first, pel.second);
}
if (sumLiteralMap.isEmpty()) {
Map<NamedExpression, Pair<SumInfo, Literal>> validSumLiteralMap =
removeOneSumLiteral(sumLiteralMap);
if (validSumLiteralMap.isEmpty()) {
return null;
}
return rewriteSumLiteral(agg, sumLiteralMap);
return rewriteSumLiteral(agg, validSumLiteralMap);
}).toRule(RuleType.SUM_LITERAL_REWRITE);
}

// when there only one sum literal like select count(id1 + 1), count(id2 + 1) from t, we don't rewrite them.
private Map<NamedExpression, Pair<SumInfo, Literal>> removeOneSumLiteral(
Map<NamedExpression, Pair<SumInfo, Literal>> sumLiteralMap) {
Map<Expression, Integer> countSum = new HashMap<>();
for (Entry<NamedExpression, Pair<SumInfo, Literal>> e : sumLiteralMap.entrySet()) {
Expression expr = e.getValue().first.expr;
countSum.merge(expr, 1, Integer::sum);
}
Map<NamedExpression, Pair<SumInfo, Literal>> validSumLiteralMap = new HashMap<>();
for (Entry<NamedExpression, Pair<SumInfo, Literal>> e : sumLiteralMap.entrySet()) {
Expression expr = e.getValue().first.expr;
if (countSum.get(expr) > 1) {
validSumLiteralMap.put(e.getKey(), e.getValue());
}
}
return validSumLiteralMap;
}

private Plan rewriteSumLiteral(
LogicalAggregate<?> agg, Map<NamedExpression, Pair<SumInfo, Literal>> sumLiteralMap) {
Set<NamedExpression> newAggOutput = new HashSet<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,35 @@ void testSumDistinct() {
.printlnTree()
.matches(logicalAggregate().when(p -> p.getOutputs().size() == 4));
}

@Test
void testSumOnce() {
Slot slot1 = scan1.getOutput().get(0);
Alias add1 = new Alias(new Sum(false, true, new Add(slot1, Literal.of(1))));
LogicalAggregate<?> agg = new LogicalAggregate<>(
ImmutableList.of(scan1.getOutput().get(0)), ImmutableList.of(add1), scan1);
PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
.applyTopDown(ImmutableList.of(new SumLiteralRewrite().build()))
.printlnTree()
.matches(logicalAggregate().when(p -> p.getOutputs().size() == 1));

Slot slot2 = new Alias(scan1.getOutput().get(0)).toSlot();
Alias add2 = new Alias(new Sum(false, true, new Add(slot2, Literal.of(2))));
agg = new LogicalAggregate<>(
ImmutableList.of(scan1.getOutput().get(0)), ImmutableList.of(add1, add2), scan1);
PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
.applyTopDown(ImmutableList.of(new SumLiteralRewrite().build()))
.printlnTree()
.matches(logicalAggregate().when(p -> p.getOutputs().size() == 2));

Alias add3 = new Alias(new Sum(false, true, new Add(slot1, Literal.of(3))));
Alias add4 = new Alias(new Sum(false, true, new Add(slot1, Literal.of(4))));
agg = new LogicalAggregate<>(
ImmutableList.of(scan1.getOutput().get(0)), ImmutableList.of(add1, add2, add3, add4), scan1);
PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
.applyTopDown(ImmutableList.of(new SumLiteralRewrite().build()))
.printlnTree()
.matches(logicalAggregate().when(p -> p.getOutputs().size() == 3));

}
}