Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContains;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.planner.NestedLoopJoinNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TRuntimeFilterType;

Expand Down Expand Up @@ -66,6 +68,11 @@ public Rule build() {
// commuting nest loop mark join or left anti mark join is not supported by be
.whenNot(join -> join.isMarkJoin() && (join.getHashJoinConjuncts().isEmpty()
|| join.getJoinType().isLeftAntiJoin()))
// For a nested loop join, if commutativity causes a join that could originally be executed
// in parallel to become non-parallelizable, then we reject this swap.
.whenNot(join -> JoinUtils.shouldNestedLoopJoin(join)
&& NestedLoopJoinNode.canParallelize(JoinType.toJoinOperator(join.getJoinType()))
&& !NestedLoopJoinNode.canParallelize(JoinType.toJoinOperator(join.getJoinType().swap())))
.then(join -> {
LogicalJoin<Plan, Plan> newJoin = join.withTypeChildren(join.getJoinType().swap(),
join.right(), join.left(), null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,16 @@ public NestedLoopJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, TableRe
tupleIds.addAll(inner.getOutputTupleIds());
}

public boolean canParallelize() {
public static boolean canParallelize(JoinOperator joinOp) {
return joinOp == JoinOperator.CROSS_JOIN || joinOp == JoinOperator.INNER_JOIN
|| joinOp == JoinOperator.LEFT_OUTER_JOIN || joinOp == JoinOperator.LEFT_SEMI_JOIN
|| joinOp == JoinOperator.LEFT_ANTI_JOIN || joinOp == JoinOperator.NULL_AWARE_LEFT_ANTI_JOIN;
}

public boolean canParallelize() {
return canParallelize(joinOp);
}

public void setJoinConjuncts(List<Expr> joinConjuncts) {
this.joinConjuncts = joinConjuncts;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.doris.nereids.rules.exploration.join;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
Expand All @@ -27,11 +29,13 @@
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class JoinCommuteTest implements MemoPatternMatchSupported {
@Test
public void testInnerJoinCommute() {
void testInnerJoinCommute() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

Expand All @@ -51,4 +55,22 @@ public void testInnerJoinCommute() {
)
;
}

@Test
void testParallelJoinCommute() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
.build();
join = join.withJoinConjuncts(
ImmutableList.of(),
ImmutableList.of(new GreaterThan(scan1.getOutput().get(0), scan2.getOutput().get(0))),
join.getJoinReorderContext());

Assertions.assertEquals(1, PlanChecker.from(MemoTestUtils.createConnectContext(), join)
.applyExploration(JoinCommute.BUSHY.build())
.getAllPlan().size());
}
}