Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@
import org.apache.doris.nereids.rules.rewrite.logical.InferJoinNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.InferPredicates;
import org.apache.doris.nereids.rules.rewrite.logical.InnerToCrossJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.logical.MergeSetOperations;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanTablet;
import org.apache.doris.nereids.rules.rewrite.logical.PushFilterInsideJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownLimit;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;

import java.util.List;
Expand Down Expand Up @@ -191,7 +191,7 @@ public class NereidsRewriter extends BatchRewriteJob {
new PruneOlapScanTablet(),
new EliminateAggregate(),
new MergeSetOperations(),
new LimitPushDown(),
new PushdownLimit(),
new BuildAggForUnion()
)),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ public enum RuleType {
PUSH_LIMIT_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_ONE_ROW_RELATION(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_EMPTY_RELATION(RuleTypeClass.REWRITE),

// adjust nullable
ADJUST_NULLABLE_ON_AGGREGATE(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
import org.apache.doris.nereids.trees.plans.algebra.Limit;
import org.apache.doris.nereids.trees.plans.algebra.OneRowRelation;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
Expand All @@ -40,24 +39,32 @@
* <p>
* Limit can't be push down if it has a valid offset info.
*/
public class LimitPushDown implements RewriteRuleFactory {
public class PushdownLimit implements RewriteRuleFactory {

@Override
public List<Rule> buildRules() {
return ImmutableList.of(
// limit -> join
logicalLimit(logicalJoin(any(), any())).whenNot(Limit::hasValidOffset)
.then(limit -> limit.withChildren(pushLimitThroughJoin(limit, limit.child())))
.then(limit -> {
Plan newJoin = pushLimitThroughJoin(limit, limit.child());
if (newJoin == null || limit.child().children().equals(newJoin.children())) {
return null;
}
return limit.withChildren(newJoin);
})
.toRule(RuleType.PUSH_LIMIT_THROUGH_JOIN),

// limit -> project -> join
logicalLimit(logicalProject(logicalJoin(any(), any()))).whenNot(Limit::hasValidOffset)
.then(limit -> {
LogicalProject<LogicalJoin<Plan, Plan>> project = limit.child();
LogicalJoin<Plan, Plan> join = project.child();
return limit.withChildren(
project.withChildren(
pushLimitThroughJoin(limit, join)));
Plan newJoin = pushLimitThroughJoin(limit, join);
if (newJoin == null || join.children().equals(newJoin.children())) {
return null;
}
return limit.withChildren(project.withChildren(newJoin));
}).toRule(RuleType.PUSH_LIMIT_THROUGH_PROJECT_JOIN),

// limit -> union
Expand All @@ -67,65 +74,45 @@ public List<Rule> buildRules() {
LogicalUnion union = limit.child();
ImmutableList<Plan> newUnionChildren = union.children()
.stream()
.map(child -> addLimit(limit, child))
.map(child -> limit.withChildren(child))
.collect(ImmutableList.toImmutableList());
if (union.children().equals(newUnionChildren)) {
return null;
}
return limit.withChildren(union.withChildren(newUnionChildren));
})
.toRule(RuleType.PUSH_LIMIT_THROUGH_UNION)
.toRule(RuleType.PUSH_LIMIT_THROUGH_UNION),
logicalLimit(logicalOneRowRelation())
.then(limit -> limit.getLimit() > 0
? limit.child() : new LogicalEmptyRelation(limit.child().getOutput()))
.toRule(RuleType.PUSH_LIMIT_THROUGH_ONE_ROW_RELATION),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this rule name is not good. it do not push limit THROUGH one row relation

logicalLimit(logicalEmptyRelation())
.then(UnaryNode::child)
.toRule(RuleType.PUSH_LIMIT_THROUGH_EMPTY_RELATION),
Comment on lines +89 to +91
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, may be eliminate limit on xxxx is better

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have a file named EliminateLimit

new MergeLimits().build()
);
}

private Plan pushLimitThroughJoin(LogicalLimit<? extends Plan> limit, LogicalJoin<Plan, Plan> join) {
switch (join.getJoinType()) {
case LEFT_OUTER_JOIN:
return join.withChildren(
addLimit(limit, join.left()),
limit.withChildren(join.left()),
join.right()
);
case RIGHT_OUTER_JOIN:
return join.withChildren(
join.left(),
addLimit(limit, join.right())
limit.withChildren(join.right())
);
case CROSS_JOIN:
return join.withChildren(
addLimit(limit, join.left()),
addLimit(limit, join.right())
limit.withChildren(join.left()),
limit.withChildren(join.right())
);
case INNER_JOIN:
if (join.hasJoinCondition()) {
return join;
} else {
return join.withChildren(
addLimit(limit, join.left()),
addLimit(limit, join.right())
);
}
default:
// don't push limit.
return join;
}
}

private Plan addLimit(LogicalLimit<? extends Plan> pushdownLimit, Plan plan) {
if (plan instanceof LogicalLimit) {
// Avoid adding duplicate limits on top of the plan, otherwise would result in dead loop
// when applying the rule multiple times.
LogicalLimit<? extends Plan> limit = (LogicalLimit<? extends Plan>) plan;
// plan is pure limit and limit value > push down limit value
if (!limit.hasValidOffset() && limit.getLimit() > pushdownLimit.getLimit()) {
// replace limit.
return pushdownLimit.withChildren(limit.child());
} else {
// return input plan.
return plan;
}
} else if (plan instanceof OneRowRelation) {
return pushdownLimit.getLimit() > 0 ? plan : new LogicalEmptyRelation(plan.getOutput());
} else if (plan instanceof EmptyRelation) {
return plan;
} else {
return pushdownLimit.withChildren(plan);
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;

class LimitPushDownTest extends TestWithFeService implements MemoPatternMatchSupported {
class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSupported {
private Plan scanScore = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.score);
private Plan scanStudent = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);

Expand Down Expand Up @@ -173,7 +173,7 @@ public void testPushLimitThroughInnerJoin() {
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
).when(j -> j.getJoinType() == JoinType.INNER_JOIN)
)
)
)
);
Expand All @@ -182,7 +182,7 @@ public void testPushLimitThroughInnerJoin() {
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
).when(j -> j.getJoinType() == JoinType.INNER_JOIN)
)
)
);
}
Expand Down Expand Up @@ -241,7 +241,8 @@ private void test(JoinType joinType, boolean hasProject, PatternDescriptor<? ext
Plan plan = generatePlan(joinType, hasProject);
PlanChecker.from(MemoTestUtils.createConnectContext())
.analyze(plan)
.applyTopDown(new LimitPushDown())
.applyTopDown(new InnerToCrossJoin())
.applyTopDown(new PushdownLimit())
.matchesFromRoot(pattern);
}

Expand Down