From 87a917fc76fbf02c7f125a7d48930c6b1936d251 Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Fri, 9 Aug 2024 15:41:27 +0800 Subject: [PATCH 1/2] [fix](nereids)semi join transpose rule produce wrong plan if there is mark join --- .../LogicalJoinSemiJoinTransposeProject.java | 4 ++++ .../SemiJoinSemiJoinTransposeProject.java | 15 +++--------- ...gicalJoinSemiJoinTransposeProjectTest.java | 24 +++++++++++++++++++ .../SemiJoinSemiJoinTransposeProjectTest.java | 10 ++++---- .../nereids/util/LogicalPlanBuilder.java | 11 +++++++++ 5 files changed, 47 insertions(+), 17 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java index a0e2b83cc1b39a..0531c6e54aca77 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java @@ -68,6 +68,8 @@ public List buildRules() { .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); bottomJoin.getOtherJoinConjuncts() .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); + bottomJoin.getMarkJoinConjuncts() + .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); Plan newBottomJoin = topJoin.withChildrenNoContext(a, c, null); Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin); Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, b); @@ -100,6 +102,8 @@ public List buildRules() { .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); bottomJoin.getOtherJoinConjuncts() .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); + bottomJoin.getMarkJoinConjuncts() + .forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); Plan newBottomJoin = topJoin.withChildrenNoContext(a, b, null); Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin); Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, c); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java index b4a5b177f8c0a9..5f0ee68058f588 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java @@ -20,19 +20,17 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; -import java.util.HashSet; import java.util.Set; import java.util.stream.Collectors; @@ -100,15 +98,8 @@ public Rule build() { newBottomSemi.getJoinReorderContext().setHasCommute(false); newBottomSemi.getJoinReorderContext().setHasLAsscom(false); - Set topUsedExprIds = new HashSet<>(); - topProject.getProjects().forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); - bottomSemi.getHashJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); - bottomSemi.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); - - Plan left = CBOUtils.newProject(topUsedExprIds, newBottomSemi); - Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, b); - - LogicalJoin newTopSemi = bottomSemi.withChildrenNoContext(left, right, null); + LogicalProject acProject = new LogicalProject<>(Lists.newArrayList(acProjects), newBottomSemi); + LogicalJoin newTopSemi = bottomSemi.withChildrenNoContext(acProject, b, null); newTopSemi.getJoinReorderContext().copyFrom(topSemi.getJoinReorderContext()); newTopSemi.getJoinReorderContext().setHasLAsscom(true); return topProject.withChildren(newTopSemi); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProjectTest.java index 27e162f4af79f1..70f6de6c320ac0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProjectTest.java @@ -105,4 +105,28 @@ public void generateTopProject() { ) ); } + + @Test + public void generateTopProjectMarkJoin() { + LogicalPlan topJoin1 = new LogicalPlanBuilder(scan1) + .markJoinWithMarkConjuncts(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .project(ImmutableList.of(1)) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t3.id + .project(ImmutableList.of(0)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin1) + .applyExploration(LogicalJoinSemiJoinTransposeProject.INSTANCE.buildRules()) + .matchesExploration( + logicalProject( + leftSemiLogicalJoin( + logicalProject(innerLogicalJoin( + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) + )), + logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))) + ) + ) + ); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java index dd654a2f42839d..970503f074fa09 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java @@ -65,7 +65,7 @@ public void testSemiProjectSemiCommute() { logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) ).when(join -> join.getJoinType() == JoinType.LEFT_SEMI_JOIN)), - logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))) + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")) ).when(join -> join.getJoinType() == JoinType.LEFT_ANTI_JOIN) ) ); @@ -74,10 +74,10 @@ public void testSemiProjectSemiCommute() { @Test public void testSemiProjectSemiCommuteMarkJoin() { LogicalPlan topJoin = new LogicalPlanBuilder(scan1) - .markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .markJoinWithMarkConjuncts(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) .project(ImmutableList.of(0, 2)) - .markJoin(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 1)) - .projectAll() + .markJoinWithMarkConjuncts(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 1)) + .project(ImmutableList.of(0)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) .applyExploration(SemiJoinSemiJoinTransposeProject.INSTANCE.build()) @@ -90,7 +90,7 @@ public void testSemiProjectSemiCommuteMarkJoin() { logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) ).when(join -> join.getJoinType() == JoinType.LEFT_SEMI_JOIN) ).when(project -> project.getProjects().size() == 2), - logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))) + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")) ).when(join -> join.getJoinType() == JoinType.LEFT_SEMI_JOIN) ) ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java index ba81fa7e4a39cf..ccdcd253279571 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java @@ -120,6 +120,17 @@ public LogicalPlanBuilder markJoin(LogicalPlan right, JoinType joinType, Pair hashOnSlots) { + ImmutableList markConjuncts = ImmutableList.of( + new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second))); + + LogicalJoin join = new LogicalJoin<>(joinType, Collections.emptyList(), + Collections.emptyList(), new ArrayList<>(markConjuncts), + new DistributeHint(DistributeType.NONE), Optional.of(new MarkJoinSlotReference("fake")), + this.plan, right, null); + return from(join); + } + public LogicalPlanBuilder join(LogicalPlan right, JoinType joinType, Pair hashOnSlots) { ImmutableList hashConjuncts = ImmutableList.of( new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second))); From 0c89fbcaf06071d35d419692da7127714ebff4c8 Mon Sep 17 00:00:00 2001 From: starocean999 <12095047@qq.com> Date: Fri, 9 Aug 2024 20:36:13 +0800 Subject: [PATCH 2/2] fix failed case --- .../join/SemiJoinSemiJoinTransposeProject.java | 16 +++++++++++++--- .../SemiJoinSemiJoinTransposeProjectTest.java | 6 +++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java index 5f0ee68058f588..359d6e13552c18 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java @@ -20,17 +20,19 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; +import java.util.HashSet; import java.util.Set; import java.util.stream.Collectors; @@ -98,8 +100,16 @@ public Rule build() { newBottomSemi.getJoinReorderContext().setHasCommute(false); newBottomSemi.getJoinReorderContext().setHasLAsscom(false); - LogicalProject acProject = new LogicalProject<>(Lists.newArrayList(acProjects), newBottomSemi); - LogicalJoin newTopSemi = bottomSemi.withChildrenNoContext(acProject, b, null); + Set topUsedExprIds = new HashSet<>(); + topProject.getProjects().forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); + bottomSemi.getHashJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); + bottomSemi.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); + bottomSemi.getMarkJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds())); + + Plan left = CBOUtils.newProject(topUsedExprIds, newBottomSemi); + Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, b); + + LogicalJoin newTopSemi = bottomSemi.withChildrenNoContext(left, right, null); newTopSemi.getJoinReorderContext().copyFrom(topSemi.getJoinReorderContext()); newTopSemi.getJoinReorderContext().setHasLAsscom(true); return topProject.withChildren(newTopSemi); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java index 970503f074fa09..d37be0a1a13fe5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProjectTest.java @@ -65,7 +65,7 @@ public void testSemiProjectSemiCommute() { logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) ).when(join -> join.getJoinType() == JoinType.LEFT_SEMI_JOIN)), - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")) + logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))) ).when(join -> join.getJoinType() == JoinType.LEFT_ANTI_JOIN) ) ); @@ -77,7 +77,7 @@ public void testSemiProjectSemiCommuteMarkJoin() { .markJoinWithMarkConjuncts(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) .project(ImmutableList.of(0, 2)) .markJoinWithMarkConjuncts(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 1)) - .project(ImmutableList.of(0)) + .project(ImmutableList.of(1, 2)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) .applyExploration(SemiJoinSemiJoinTransposeProject.INSTANCE.build()) @@ -90,7 +90,7 @@ public void testSemiProjectSemiCommuteMarkJoin() { logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) ).when(join -> join.getJoinType() == JoinType.LEFT_SEMI_JOIN) ).when(project -> project.getProjects().size() == 2), - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")) + logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))) ).when(join -> join.getJoinType() == JoinType.LEFT_SEMI_JOIN) ) );