diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
index 2f8e1404b7199e..fe2d7072ef5d98 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
@@ -21,9 +21,11 @@
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
@@ -50,24 +52,44 @@
* 3. In old optimizer, there is `InferFilterRule` generates redundancy expressions. Its Nereid counterpart also need
* `RemoveRedundantExpression`.
*
- * TODO: This rule just match filter, but it could be applied to inner/cross join condition.
*/
-public class ExtractSingleTableExpressionFromDisjunction extends OneRewriteRuleFactory {
+public class ExtractSingleTableExpressionFromDisjunction implements RewriteRuleFactory {
+ // Join types accepted by the join rule built in buildRules(); a join of any other type
+ // does not match that rule and is left untouched by it.
+ private static final ImmutableSet ALLOW_JOIN_TYPE = ImmutableSet.of(JoinType.INNER_JOIN,
+ JoinType.LEFT_OUTER_JOIN, JoinType.RIGHT_OUTER_JOIN, JoinType.LEFT_SEMI_JOIN, JoinType.RIGHT_SEMI_JOIN,
+ JoinType.LEFT_ANTI_JOIN, JoinType.RIGHT_ANTI_JOIN, JoinType.CROSS_JOIN, JoinType.FULL_OUTER_JOIN);
+
@Override
- public Rule build() {
- return logicalFilter().then(filter -> {
- List dependentPredicates = extractDependentConjuncts(filter.getConjuncts());
- if (dependentPredicates.isEmpty()) {
- return null;
- }
- Set newPredicates = ImmutableSet.builder()
- .addAll(filter.getConjuncts())
- .addAll(dependentPredicates).build();
- if (newPredicates.size() == filter.getConjuncts().size()) {
- return null;
- }
- return new LogicalFilter<>(newPredicates, filter.child());
- }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION);
+ public List buildRules() {
+     // Builds two rules that share one rule type:
+     //  1. a filter rule that appends the extracted single-table predicates to the filter conjuncts;
+     //  2. a join rule that appends them to the other-join conjuncts of joins whose type is
+     //     listed in ALLOW_JOIN_TYPE.
+     // Each rule returns null when it has nothing new to add.
+     return ImmutableList.of(
+             logicalFilter().then(filter -> {
+                 List dependentPredicates = extractDependentConjuncts(filter.getConjuncts());
+                 if (dependentPredicates.isEmpty()) {
+                     return null;
+                 }
+                 // Building a set drops extracted predicates that the filter already contains.
+                 Set newPredicates = ImmutableSet.builder()
+                         .addAll(filter.getConjuncts())
+                         .addAll(dependentPredicates).build();
+                 // Equal sizes mean every extracted predicate was already present.
+                 if (newPredicates.size() == filter.getConjuncts().size()) {
+                     return null;
+                 }
+                 return new LogicalFilter<>(newPredicates, filter.child());
+             }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION),
+             logicalJoin().when(join -> ALLOW_JOIN_TYPE.contains(join.getJoinType())).then(join -> {
+                 List dependentOtherPredicates = extractDependentConjuncts(
+                         ImmutableSet.copyOf(join.getOtherJoinConjuncts()));
+                 // The other-join conjuncts are a list and may contain duplicates, so comparing the
+                 // size of the de-duplicated merged set with the list size is unreliable: a new
+                 // predicate could be masked by a duplicate. Test membership instead; this also
+                 // covers the case where nothing was extracted at all.
+                 if (join.getOtherJoinConjuncts().containsAll(dependentOtherPredicates)) {
+                     return null;
+                 }
+                 // Original conjuncts first, then the extracted ones, without duplicates.
+                 Set newOtherPredicates = ImmutableSet.builder()
+                         .addAll(join.getOtherJoinConjuncts())
+                         .addAll(dependentOtherPredicates).build();
+                 // Only the other-join conjuncts change; hash conjuncts, mark conjuncts and the
+                 // reorder context are passed through unchanged.
+                 return join.withJoinConjuncts(join.getHashJoinConjuncts(),
+                         ImmutableList.copyOf(newOtherPredicates),
+                         join.getMarkJoinConjuncts(), join.getJoinReorderContext());
+             }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION));
}
private List extractDependentConjuncts(Set conjuncts) {
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
index fc55f473ee6417..39706d39f2cb0d 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
@@ -29,6 +29,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
@@ -41,6 +42,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
+import java.util.List;
import java.util.Set;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@@ -179,4 +181,38 @@ private boolean verifySingleTableExpression3(Set conjuncts) {
return conjuncts.size() == 2 && conjuncts.contains(or);
}
+
+ /**
+ * Tests extraction from the other-join conjuncts of a join.
+ * (cid=1 and sage=10) or sgender=1
+ * =>
+ * (sage=10 or sgender=1)
+ * The rewritten join is expected to hold exactly two other-join conjuncts,
+ * one of which is the extracted disjunction above.
+ */
+ @Test
+ public void testExtract4() {
+ // Disjunction that mixes a course column with student columns.
+ Expression expr = new Or(
+ new And(
+ new EqualTo(courseCid, new IntegerLiteral(1)),
+ new EqualTo(studentAge, new IntegerLiteral(10))
+ ),
+ new EqualTo(studentGender, new IntegerLiteral(1))
+ );
+ // Cross join with no hash conjuncts; the disjunction is the only other-join conjunct.
+ Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, ExpressionUtils.EMPTY_CONDITION, ImmutableList.of(expr),
+ student, course, null);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), join)
+ .applyTopDown(new ExtractSingleTableExpressionFromDisjunction())
+ .matchesFromRoot(
+ logicalJoin()
+ .when(j -> verifySingleTableExpression4(j.getOtherJoinConjuncts()))
+ );
+ // NOTE(review): this assertion does not check the rewrite itself; presumably it is only
+ // here so the test contains an explicit assertion -- confirm before removing.
+ Assertions.assertNotNull(studentGender);
+ }
+
+ private boolean verifySingleTableExpression4(List conjuncts) {
+     // Expected extracted predicate: (sage = 10) OR (sgender = 1).
+     EqualTo ageIsTen = new EqualTo(studentAge, new IntegerLiteral(10));
+     EqualTo genderIsOne = new EqualTo(studentGender, new IntegerLiteral(1));
+     Expression expected = new Or(ageIsTen, genderIsOne);
+     // The join must carry exactly two other-join conjuncts ...
+     if (conjuncts.size() != 2) {
+         return false;
+     }
+     // ... and one of them must be the extracted disjunction.
+     return conjuncts.contains(expected);
+ }
}
diff --git a/regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out b/regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out
new file mode 100644
index 00000000000000..9077ecb24b9b56
--- /dev/null
+++ b/regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out
@@ -0,0 +1,94 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !left_semi --
+PhysicalResultSink
+--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))))
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !right_semi --
+PhysicalResultSink
+--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))))
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+
+-- !left --
+PhysicalResultSink
+--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (1, 2))
+----PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !right --
+PhysicalResultSink
+--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (8, 9))
+----PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+
+-- !left_anti --
+PhysicalResultSink
+--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (1, 2))
+----PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !right_anti --
+PhysicalResultSink
+--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (8, 9))
+----PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+
+-- !inner --
+PhysicalResultSink
+--hashJoin[INNER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))))
+----filter(a IN (1, 2))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !outer --
+PhysicalResultSink
+--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (1, 2))
+----filter((t1.c = 3))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
+----filter(a IN (8, 9))
+------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
+
+-- !left_semi_res --
+1
+2
+
+-- !right_semi_res --
+8
+9
+
+-- !left_res --
+1
+2
+3
+
+-- !right_res --
+\N
+1
+2
+
+-- !left_anti_res --
+3
+
+-- !right_anti_res --
+7
+
+-- !inner_res --
+1
+2
+
+-- !outer_res --
+1
+2
+3
+
diff --git a/regression-test/data/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.out b/regression-test/data/nereids_rules_p0/push_down_filter/push_down_filter_through_window.out
similarity index 100%
rename from regression-test/data/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.out
rename to regression-test/data/nereids_rules_p0/push_down_filter/push_down_filter_through_window.out
diff --git a/regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy b/regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy
new file mode 100644
index 00000000000000..858f39e5e65cf2
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("extract_from_disjunction_in_join") {
+ // Regression suite: for several join types, checks the plan shape and the query result when
+ // the ON clause contains a disjunction that references columns of both joined tables.
+
+ // Session settings applied before the checks below.
+ sql "SET enable_nereids_planner=true"
+ sql "SET enable_fallback_to_original_planner=false"
+ sql "set ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"
+ sql "set disable_nereids_rules=PRUNE_EMPTY_PARTITION"
+ sql "set runtime_filter_mode=OFF"
+
+
+ // Recreate the two test tables; both have the same four-column schema.
+ sql "drop table if exists extract_from_disjunction_in_join_t1"
+ sql "drop table if exists extract_from_disjunction_in_join_t2"
+ sql """
+ CREATE TABLE `extract_from_disjunction_in_join_t1` (
+ `a` INT NULL,
+ `b` VARCHAR(10) NULL,
+ `c` INT NULL,
+ `d` INT NULL
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`a`, `b`)
+ DISTRIBUTED BY RANDOM BUCKETS AUTO
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ );
+ """
+ sql """
+ CREATE TABLE `extract_from_disjunction_in_join_t2` (
+ `a` INT NULL,
+ `b` VARCHAR(10) NULL,
+ `c` INT NULL,
+ `d` INT NULL
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`a`, `b`)
+ DISTRIBUTED BY RANDOM BUCKETS AUTO
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ );"""
+
+ // t1 holds a = 1, 2, 3 and t2 holds a = 7, 8, 9; every row has b = 'd2', so t1.b = t2.b
+ // matches every pair of rows.
+ sql "insert into extract_from_disjunction_in_join_t1 values(1,'d2',3,5),(2,'d2',3,5),(3,'d2',3,5);"
+ sql "insert into extract_from_disjunction_in_join_t2 values(7,'d2',2,2),(8,'d2',2,2),(9,'d2',2,2);"
+ // Plan-shape checks: the same ON condition under each join type.
+ qt_left_semi """explain shape plan
+ select * from extract_from_disjunction_in_join_t1 t1 left semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
+ qt_right_semi """explain shape plan
+ select * from extract_from_disjunction_in_join_t1 t1 right semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
+ qt_left """explain shape plan
+ select * from extract_from_disjunction_in_join_t1 t1 left join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
+ qt_right """explain shape plan
+ select * from extract_from_disjunction_in_join_t1 t1 right join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
+ qt_left_anti """explain shape plan
+ select * from extract_from_disjunction_in_join_t1 t1 left anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
+ qt_right_anti """explain shape plan
+ select * from extract_from_disjunction_in_join_t1 t1 right anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
+ qt_inner """explain shape plan
+ select * from extract_from_disjunction_in_join_t1 t1 inner join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
+ // Full join combined with a WHERE predicate on t1; the recorded expected plan for this
+ // case shows a LEFT_OUTER_JOIN.
+ qt_outer """explain shape plan
+ select * from extract_from_disjunction_in_join_t1 t1 full join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8)
+ where t1.c=3;"""
+
+ // Result checks for the same joins; each query is ordered by its single output column.
+ qt_left_semi_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+ qt_right_semi_res "select t2.a from extract_from_disjunction_in_join_t1 t1 right semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+ qt_left_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+ qt_right_res "select t1.a from extract_from_disjunction_in_join_t1 t1 right join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+ qt_left_anti_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+ qt_right_anti_res "select t2.a from extract_from_disjunction_in_join_t1 t1 right anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+ qt_inner_res "select t1.a from extract_from_disjunction_in_join_t1 t1 inner join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
+ qt_outer_res """select t1.a from extract_from_disjunction_in_join_t1 t1 full join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8)
+ where t1.c=3 order by 1;"""
+}
\ No newline at end of file
diff --git a/regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.groovy b/regression-test/suites/nereids_rules_p0/push_down_filter/push_down_filter_through_window.groovy
similarity index 100%
rename from regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.groovy
rename to regression-test/suites/nereids_rules_p0/push_down_filter/push_down_filter_through_window.groovy