diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index 13415c504f1ff4..16ad3bc0bd1597 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -85,7 +85,8 @@ public class StatementContext { private final Map> cteIdToConsumerUnderProjects = new HashMap<>(); // Used to update consumer's stats private final Map, Group>>> cteIdToConsumerGroup = new HashMap<>(); - private final Map rewrittenCtePlan = new HashMap<>(); + private final Map rewrittenCteProducer = new HashMap<>(); + private final Map rewrittenCteConsumer = new HashMap<>(); private final Map hintMap = Maps.newLinkedHashMap(); private final Set viewDdlSqlSet = Sets.newHashSet(); @@ -230,8 +231,12 @@ public Map, Group>>> getCteIdToConsumerGroup() return cteIdToConsumerGroup; } - public Map getRewrittenCtePlan() { - return rewrittenCtePlan; + public Map getRewrittenCteProducer() { + return rewrittenCteProducer; + } + + public Map getRewrittenCteConsumer() { + return rewrittenCteConsumer; } public void addViewDdlSql(String ddlSql) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java index 2e5132f4ddd4ed..2a7f0903b2501c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java @@ -89,7 +89,7 @@ private Optional> getCost(CascadesContext currentCtx CascadesContext rootCtx = currentCtx.getRoot(); if (rootCtx.getRewritePlan() instanceof LogicalCTEAnchor) { // set subtree rewrite cache - currentCtx.getStatementContext().getRewrittenCtePlan() + currentCtx.getStatementContext().getRewrittenCteProducer() .put(currentCtx.getCurrentTree().orElse(null), (LogicalPlan) cboCtx.getRewritePlan()); // Do Whole tree rewrite CascadesContext rootCtxCopy = CascadesContext.newCurrentTreeContext(rootCtx); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildren.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildren.java index 5aa286e67f9c27..3318f9990d8b40 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildren.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildren.java @@ -77,14 +77,14 @@ public Plan visit(Plan plan, CascadesContext context) { public Plan visitLogicalCTEAnchor(LogicalCTEAnchor cteAnchor, CascadesContext cascadesContext) { LogicalPlan outer; - if (cascadesContext.getStatementContext().getRewrittenCtePlan().containsKey(null)) { - outer = cascadesContext.getStatementContext().getRewrittenCtePlan().get(null); + if (cascadesContext.getStatementContext().getRewrittenCteConsumer().containsKey(cteAnchor.getCteId())) { + outer = cascadesContext.getStatementContext().getRewrittenCteProducer().get(cteAnchor.getCteId()); } else { CascadesContext outerCascadesCtx = CascadesContext.newSubtreeContext( Optional.empty(), cascadesContext, cteAnchor.child(1), cascadesContext.getCurrentJobContext().getRequiredProperties()); outer = (LogicalPlan) cteAnchor.child(1).accept(this, outerCascadesCtx); - cascadesContext.getStatementContext().getRewrittenCtePlan().put(null, outer); + cascadesContext.getStatementContext().getRewrittenCteConsumer().put(cteAnchor.getCteId(), outer); } boolean reserveAnchor = outer.anyMatch(p -> { if (p instanceof LogicalCTEConsumer) { @@ -104,8 +104,8 @@ public Plan visitLogicalCTEAnchor(LogicalCTEAnchor cteProducer, CascadesContext cascadesContext) { LogicalPlan child; - if (cascadesContext.getStatementContext().getRewrittenCtePlan().containsKey(cteProducer.getCteId())) { - child = cascadesContext.getStatementContext().getRewrittenCtePlan().get(cteProducer.getCteId()); + if (cascadesContext.getStatementContext().getRewrittenCteProducer().containsKey(cteProducer.getCteId())) { + child = cascadesContext.getStatementContext().getRewrittenCteProducer().get(cteProducer.getCteId()); } else { child = (LogicalPlan) cteProducer.child(); child = tryToConstructFilter(cascadesContext, cteProducer.getCteId(), child); @@ -118,7 +118,7 @@ public Plan visitLogicalCTEProducer(LogicalCTEProducer cteProduc CascadesContext rewrittenCtx = CascadesContext.newSubtreeContext( Optional.of(cteProducer.getCteId()), cascadesContext, child, PhysicalProperties.ANY); child = (LogicalPlan) child.accept(this, rewrittenCtx); - cascadesContext.getStatementContext().getRewrittenCtePlan().put(cteProducer.getCteId(), child); + cascadesContext.getStatementContext().getRewrittenCteProducer().put(cteProducer.getCteId(), child); } return cteProducer.withChildren(child); } diff --git a/regression-test/suites/nereids_syntax_p0/cte.groovy b/regression-test/suites/nereids_syntax_p0/cte.groovy index 2fffefd8067898..ba945569919394 100644 --- a/regression-test/suites/nereids_syntax_p0/cte.groovy +++ b/regression-test/suites/nereids_syntax_p0/cte.groovy @@ -324,5 +324,10 @@ suite("cte") { ) tab WHERE Id IN (1, 2) """ + + // rewrite cte children should work well with cost based rewrite rule. rely on rewrite rule: InferSetOperatorDistinct + sql """ + WITH cte_0 AS ( SELECT 1 AS a ), cte_1 AS ( SELECT 1 AS a ) select * from cte_0, cte_1 union select * from cte_0, cte_1 + """ }