diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java index a0e2b83cc1b39a..0531c6e54aca77 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java @@ -68,6 +68,8 @@ public List buildRules() { .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); bottomJoin.getOtherJoinConjuncts() .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); + bottomJoin.getMarkJoinConjuncts() + .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); Plan newBottomJoin = topJoin.withChildrenNoContext(a, c, null); Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin); Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, b); @@ -100,6 +102,8 @@ public List buildRules() { .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); bottomJoin.getOtherJoinConjuncts() .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); + bottomJoin.getMarkJoinConjuncts() + .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); Plan newBottomJoin = topJoin.withChildrenNoContext(a, b, null); Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin); Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, c); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java index b4a5b177f8c0a9..359d6e13552c18 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java @@ -104,6 +104,7 @@ public Rule build() { topProject.getProjects().forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); bottomSemi.getHashJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); bottomSemi.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); + bottomSemi.getMarkJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); Plan left = CBOUtils.newProject(topUsedExprIds, newBottomSemi); Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, b); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProjectTest.java index 27e162f4af79f1..70f6de6c320ac0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProjectTest.java @@ -105,4 +105,28 @@ public void generateTopProject() { ) ); } + + @Test + public void generateTopProjectMarkJoin() { + LogicalPlan topJoin1 = new LogicalPlanBuilder(scan1) + .markJoinWithMarkConjuncts(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .project(ImmutableList.of(1)) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t3.id + .project(ImmutableList.of(0)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin1) + .applyExploration(LogicalJoinSemiJoinTransposeProject.INSTANCE.buildRules()) + .matchesExploration( + logicalProject( + leftSemiLogicalJoin( + logicalProject(innerLogicalJoin( + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) + )), + logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))) + ) + ) + ); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java index dd654a2f42839d..d37be0a1a13fe5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java @@ -74,10 +74,10 @@ public void testSemiProjectSemiCommute() { @Test public void testSemiProjectSemiCommuteMarkJoin() { LogicalPlan topJoin = new LogicalPlanBuilder(scan1) - .markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .markJoinWithMarkConjuncts(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) .project(ImmutableList.of(0, 2)) - .markJoin(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 1)) - .projectAll() + .markJoinWithMarkConjuncts(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 1)) + .project(ImmutableList.of(1, 2)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) .applyExploration(SemiJoinSemiJoinTransposeProject.INSTANCE.build()) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java index ba81fa7e4a39cf..ccdcd253279571 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java @@ -120,6 +120,17 @@ public LogicalPlanBuilder markJoin(LogicalPlan right, JoinType joinType, Pair hashOnSlots) { + ImmutableList markConjuncts = ImmutableList.of( + new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second))); + + LogicalJoin join = new LogicalJoin<>(joinType, Collections.emptyList(), + Collections.emptyList(), new ArrayList<>(markConjuncts), + new DistributeHint(DistributeType.NONE), Optional.of(new MarkJoinSlotReference("fake")), + this.plan, right, null); + return from(join); + } + public LogicalPlanBuilder join(LogicalPlan right, JoinType joinType, Pair hashOnSlots) { ImmutableList hashConjuncts = ImmutableList.of( new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second)));