diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index 9be4b89e57c968..6691674426270c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -148,7 +148,7 @@ public class CascadesContext implements ScheduleContext { */ private CascadesContext(Optional parent, Optional currentTree, StatementContext statementContext, Plan plan, Memo memo, - CTEContext cteContext, PhysicalProperties requireProperties) { + CTEContext cteContext, PhysicalProperties requireProperties, boolean isLeadingDisableJoinReorder) { this.parent = Objects.requireNonNull(parent, "parent should not null"); this.currentTree = Objects.requireNonNull(currentTree, "currentTree should not null"); this.statementContext = Objects.requireNonNull(statementContext, "statementContext should not null"); @@ -169,6 +169,10 @@ private CascadesContext(Optional parent, Optional curren } else { this.isEnableExprTrace = false; } + if (parent.isPresent()) { + this.tables = parent.get().tables; + } + this.isLeadingDisableJoinReorder = isLeadingDisableJoinReorder; } /** @@ -177,7 +181,7 @@ private CascadesContext(Optional parent, Optional curren public static CascadesContext initContext(StatementContext statementContext, Plan initPlan, PhysicalProperties requireProperties) { return newContext(Optional.empty(), Optional.empty(), statementContext, - initPlan, new CTEContext(), requireProperties); + initPlan, new CTEContext(), requireProperties, false); } /** @@ -186,14 +190,15 @@ public static CascadesContext initContext(StatementContext statementContext, public static CascadesContext newContextWithCteContext(CascadesContext cascadesContext, Plan initPlan, CTEContext cteContext) { return newContext(Optional.of(cascadesContext), Optional.empty(), - cascadesContext.getStatementContext(), initPlan, cteContext, PhysicalProperties.ANY + cascadesContext.getStatementContext(), initPlan, cteContext, PhysicalProperties.ANY, + cascadesContext.isLeadingDisableJoinReorder ); } public static CascadesContext newCurrentTreeContext(CascadesContext context) { return CascadesContext.newContext(context.getParent(), context.getCurrentTree(), context.getStatementContext(), context.getRewritePlan(), context.getCteContext(), - context.getCurrentJobContext().getRequiredProperties()); + context.getCurrentJobContext().getRequiredProperties(), context.isLeadingDisableJoinReorder); } /** @@ -202,14 +207,14 @@ public static CascadesContext newCurrentTreeContext(CascadesContext context) { public static CascadesContext newSubtreeContext(Optional subtree, CascadesContext context, Plan plan, PhysicalProperties requireProperties) { return CascadesContext.newContext(Optional.of(context), subtree, context.getStatementContext(), - plan, context.getCteContext(), requireProperties); + plan, context.getCteContext(), requireProperties, context.isLeadingDisableJoinReorder); } private static CascadesContext newContext(Optional parent, Optional subtree, StatementContext statementContext, Plan initPlan, CTEContext cteContext, - PhysicalProperties requireProperties) { + PhysicalProperties requireProperties, boolean isLeadingDisableJoinReorder) { return new CascadesContext(parent, subtree, statementContext, initPlan, null, - cteContext, requireProperties); + cteContext, requireProperties, isLeadingDisableJoinReorder); } public CascadesContext getRoot() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTE.java index 0fe083c1e93230..129b0860a74ee4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTE.java @@ -72,6 +72,7 @@ public Rule build() { CascadesContext outerCascadesCtx = CascadesContext.newContextWithCteContext( ctx.cascadesContext, logicalCTE.child(), result.first); outerCascadesCtx.newAnalyzer().analyze(); + ctx.cascadesContext.setLeadingDisableJoinReorder(outerCascadesCtx.isLeadingDisableJoinReorder()); Plan root = outerCascadesCtx.getRewritePlan(); // should construct anchor from back to front, because the cte behind depends on the front for (int i = result.second.size() - 1; i >= 0; i--) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java index abe82c858d4e38..a91c0dd47126fc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java @@ -68,6 +68,11 @@ public class AnalyzeCTETest extends TestWithFeService implements MemoPatternMatc + "cte2 AS (SELECT sk FROM cte1 WHERE sk < 3)" + "SELECT * FROM cte1 JOIN cte2 ON cte1.sk = cte2.sk"; + private final String cteLeadingJoin = "WITH cte1 AS (SELECT /*+ leading(supplier customer) */ s_suppkey AS sk " + + "FROM supplier join customer on c_nation = s_nation), " + + "cte2 AS (SELECT sk FROM cte1 WHERE sk < 3)" + + "SELECT /*+ leading(cte2 cte1) */ * FROM cte1 JOIN cte2 ON cte1.sk = cte2.sk"; + private final String cteReferToAnotherOne = "WITH V1 AS (SELECT s_suppkey FROM supplier), " + "V2 AS (SELECT s_suppkey FROM V1)" + "SELECT * FROM V2"; @@ -128,6 +133,15 @@ public List getExplorationRules() { } } + @Test + public void testLeadingCte() throws Exception { + StatementScopeIdGenerator.clear(); + StatementContext statementContext = MemoTestUtils.createStatementContext(connectContext, cteLeadingJoin); + NereidsPlanner planner = new NereidsPlanner(statementContext); + planner.planWithLock(parser.parseSingle(cteLeadingJoin), PhysicalProperties.ANY); + Assertions.assertTrue(planner.getCascadesContext().isLeadingDisableJoinReorder()); + } + @Test public void testCTEInHavingAndSubquery() { diff --git a/regression-test/data/nereids_hint_tpcds_p0/shape/query64.out b/regression-test/data/nereids_hint_tpcds_p0/shape/query64.out index 9af385e063936f..26a67aa0d6e85a 100644 --- a/regression-test/data/nereids_hint_tpcds_p0/shape/query64.out +++ b/regression-test/data/nereids_hint_tpcds_p0/shape/query64.out @@ -100,7 +100,7 @@ PhysicalCteAnchor ( cteId=CTEId#1 ) ------------------PhysicalCteConsumer ( cteId=CTEId#1 ) Hint log: -Used: leading(catalog_sales catalog_returns ) leading({ store_sales { { customer d2 } cd2 } } cd1 d3 item { hd1 ib1 } store_returns ad1 hd2 ad2 ib2 d1 store promotion cs_ui ) leading(cs1 cs2 ) -UnUsed: +Used: leading(catalog_sales shuffle catalog_returns ) leading({ store_sales { { customer d2 } cd2 } } cd1 d3 item { hd1 ib1 } store_returns ad1 hd2 ad2 ib2 d1 store promotion cs_ui ) leading(cs1 shuffle cs2 ) +UnUsed: SyntaxError: diff --git a/regression-test/data/nereids_hint_tpcds_p0/shape/query81.out b/regression-test/data/nereids_hint_tpcds_p0/shape/query81.out index 465ebcbaaafba1..fcbe4a8ad57c34 100644 --- a/regression-test/data/nereids_hint_tpcds_p0/shape/query81.out +++ b/regression-test/data/nereids_hint_tpcds_p0/shape/query81.out @@ -24,15 +24,15 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------PhysicalProject ------------hashJoin[INNER_JOIN broadcast] hashCondition=((ctr1.ctr_state = ctr2.ctr_state)) otherCondition=((cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE))) build RFs:RF4 ctr_state->[ctr_state] --------------PhysicalProject -----------------hashJoin[INNER_JOIN shuffle] hashCondition=((ctr1.ctr_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[ctr_customer_sk] -------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 +----------------hashJoin[INNER_JOIN shuffle] hashCondition=((ctr1.ctr_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 ctr_customer_sk->[c_customer_sk] ------------------PhysicalProject --------------------hashJoin[INNER_JOIN broadcast] hashCondition=((customer_address.ca_address_sk = customer.c_current_addr_sk)) otherCondition=() build RFs:RF2 ca_address_sk->[c_current_addr_sk] ----------------------PhysicalProject -------------------------PhysicalOlapScan[customer] apply RFs: RF2 +------------------------PhysicalOlapScan[customer] apply RFs: RF2 RF3 ----------------------PhysicalProject ------------------------filter((customer_address.ca_state = 'TN')) --------------------------PhysicalOlapScan[customer_address] +------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF4 --------------hashAgg[GLOBAL] ----------------PhysicalDistribute[DistributionSpecHash] ------------------hashAgg[LOCAL]