diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java index 49c91b92942c15..73fb853082ced4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -25,9 +25,11 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContains; import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.JoinUtils; +import org.apache.doris.planner.NestedLoopJoinNode; import org.apache.doris.qe.ConnectContext; import org.apache.doris.thrift.TRuntimeFilterType; @@ -66,6 +68,11 @@ public Rule build() { // commuting nest loop mark join or left anti mark join is not supported by be .whenNot(join -> join.isMarkJoin() && (join.getHashJoinConjuncts().isEmpty() || join.getJoinType().isLeftAntiJoin())) + // For a nested loop join, if commutativity causes a join that could originally be executed + // in parallel to become non-parallelizable, then we reject this swap. + .whenNot(join -> JoinUtils.shouldNestedLoopJoin(join) + && NestedLoopJoinNode.canParallelize(JoinType.toJoinOperator(join.getJoinType())) + && !NestedLoopJoinNode.canParallelize(JoinType.toJoinOperator(join.getJoinType().swap()))) .then(join -> { LogicalJoin newJoin = join.withTypeChildren(join.getJoinType().swap(), join.right(), join.left(), null); diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/NestedLoopJoinNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/NestedLoopJoinNode.java index 983cbfd5884c69..30c0a2d0394fa7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/NestedLoopJoinNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/NestedLoopJoinNode.java @@ -73,12 +73,16 @@ public NestedLoopJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, TableRe tupleIds.addAll(inner.getOutputTupleIds()); } - public boolean canParallelize() { + public static boolean canParallelize(JoinOperator joinOp) { return joinOp == JoinOperator.CROSS_JOIN || joinOp == JoinOperator.INNER_JOIN || joinOp == JoinOperator.LEFT_OUTER_JOIN || joinOp == JoinOperator.LEFT_SEMI_JOIN || joinOp == JoinOperator.LEFT_ANTI_JOIN || joinOp == JoinOperator.NULL_AWARE_LEFT_ANTI_JOIN; } + public boolean canParallelize() { + return canParallelize(joinOp); + } + public void setJoinConjuncts(List joinConjuncts) { this.joinConjuncts = joinConjuncts; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java index 20323d108e64e3..18235b3ce4c649 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java @@ -18,7 +18,9 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.LogicalPlanBuilder; @@ -27,11 +29,13 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; public class JoinCommuteTest implements MemoPatternMatchSupported { @Test - public void testInnerJoinCommute() { + void testInnerJoinCommute() { LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @@ -51,4 +55,22 @@ public void testInnerJoinCommute() { ) ; } + + @Test + void testParallelJoinCommute() { + LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + + LogicalJoin join = (LogicalJoin) new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) + .build(); + join = join.withJoinConjuncts( + ImmutableList.of(), + ImmutableList.of(new GreaterThan(scan1.getOutput().get(0), scan2.getOutput().get(0))), + join.getJoinReorderContext()); + + Assertions.assertEquals(1, PlanChecker.from(MemoTestUtils.createConnectContext(), join) + .applyExploration(JoinCommute.BUSHY.build()) + .getAllPlan().size()); + } }