Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
Expand All @@ -43,6 +44,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -244,10 +246,33 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi

// create a parent project node
LogicalProject<Plan> project = new LogicalProject<>(upperProjects, newAggregate);
// verify project used slots are all coming from agg's output
List<Slot> slots = collectAllUsedSlots(upperProjects);
if (!slots.isEmpty()) {
Set<ExprId> aggOutputExprIds = new HashSet<>(slots.size());
for (NamedExpression expression : normalizedAggOutput) {
aggOutputExprIds.add(expression.getExprId());
}
List<Slot> errorSlots = new ArrayList<>(slots.size());
for (Slot slot : slots) {
if (!aggOutputExprIds.contains(slot.getExprId())) {
errorSlots.add(slot);
}
}
if (!errorSlots.isEmpty()) {
throw new AnalysisException(String.format("%s not in aggregate's output", errorSlots
.stream().map(NamedExpression::getName).collect(Collectors.joining(", "))));
}
}
if (having.isPresent()) {
if (upperProjects.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) {
// when project contains window functions, in order to get the correct result
// push having through project to make it the parent node of logicalAgg
Set<Slot> havingUsedSlots = ExpressionUtils.getInputSlotSet(having.get().getExpressions());
Set<ExprId> havingUsedExprIds = new HashSet<>(havingUsedSlots.size());
for (Slot slot : havingUsedSlots) {
havingUsedExprIds.add(slot.getExprId());
}
Set<ExprId> aggOutputExprIds = newAggregate.getOutputExprIdSet();
if (aggOutputExprIds.containsAll(havingUsedExprIds)) {
// when having just use output slots from agg, we push down having as parent of agg
return project.withChildren(ImmutableList.of(
new LogicalHaving<>(
ExpressionUtils.replace(having.get().getConjuncts(), project.getAliasToProducer()),
Expand Down Expand Up @@ -287,4 +312,15 @@ private List<NamedExpression> normalizeOutput(List<NamedExpression> aggregateOut
}
return builder.build();
}

private List<Slot> collectAllUsedSlots(List<NamedExpression> expressions) {
Set<Slot> inputSlots = ExpressionUtils.getInputSlotSet(expressions);
List<SubqueryExpr> subqueries = ExpressionUtils.collectAll(expressions, SubqueryExpr.class::isInstance);
List<Slot> slots = new ArrayList<>(inputSlots.size() + subqueries.size());
for (SubqueryExpr subqueryExpr : subqueries) {
slots.addAll(subqueryExpr.getCorrelateSlots());
}
slots.addAll(ExpressionUtils.getInputSlotSet(expressions));
return slots;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public void testCTEInHavingAndSubquery() {
logicalFilter(
logicalProject(
logicalJoin(
logicalProject(logicalAggregate()),
logicalAggregate(),
logicalProject()
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,9 @@ public void testHavingGroupBySlot() {
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))));
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))));

sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0";
SlotReference value = new SlotReference(new ExprId(3), "value", TinyIntType.INSTANCE, true,
Expand Down Expand Up @@ -134,10 +133,9 @@ public void testHavingGroupBySlot() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))
).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot()))));
}
Expand All @@ -158,10 +156,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));

Expand All @@ -171,13 +168,12 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan()
)
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
logicalAggregate(
logicalProject(
logicalOlapScan()
)
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
.when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));

sql = "SELECT a1, sum(a2) as value FROM t1 GROUP BY a1 HAVING sum(a2) > 0";
a1 = new SlotReference(
Expand All @@ -193,22 +189,20 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));

sql = "SELECT a1, sum(a2) as value FROM t1 GROUP BY a1 HAVING value > 0";
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))));

sql = "SELECT a1, sum(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
Expand All @@ -230,10 +224,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));

Expand All @@ -243,10 +236,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L)))))));

sql = "SELECT a1, sum(a1 + a2) FROM t1 GROUP BY a1 HAVING sum(a1 + a2 + 3) > 0";
Expand All @@ -256,10 +248,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));

Expand All @@ -269,10 +260,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
}
Expand All @@ -298,17 +288,16 @@ void testJoinWithHaving() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
))
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
)).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
logicalAggregate(
logicalProject(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
))
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
sumB1.toSlot()))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@
2 1 23.0000000000
2 2 23.0000000000

-- !select5 --
1 1 3.0000000000
1 2 3.0000000000

Loading