From c53e2a993f5cf1132b58c4ea26ed79e708d8def6 Mon Sep 17 00:00:00 2001 From: LiBinfeng <1204975323@qq.com> Date: Thu, 31 Aug 2023 10:57:57 +0800 Subject: [PATCH 1/2] [Fix](Nereids) fix infer predicate lost cast of original expression --- .../doris/nereids/rules/rewrite/PredicatePropagation.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index cc45952817a845..dde0ee5c8d91f0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -102,8 +102,14 @@ private Expression replaceSlot(Expression expr, DataType originDataType) { return expr.rewriteUp(e -> { if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) { if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) { + if (!leftSlotEqualToRightSlot.child(1).getDataType().equals(e.getDataType())) { + return new Cast(leftSlotEqualToRightSlot.child(1), e.getDataType()); + } return leftSlotEqualToRightSlot.child(1); } else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) { + if (!leftSlotEqualToRightSlot.child(0).getDataType().equals(e.getDataType())) { + return new Cast(leftSlotEqualToRightSlot.child(0), e.getDataType()); + } return leftSlotEqualToRightSlot.child(0); } } From f0be2951434fa831f383d5f2fb44fb8c13bec7a0 Mon Sep 17 00:00:00 2001 From: LiBinfeng <1204975323@qq.com> Date: Thu, 7 Sep 2023 17:47:41 +0800 Subject: [PATCH 2/2] [Fix](Nereids) fix infer predicate lost some datatypes --- .../rules/rewrite/PredicatePropagation.java | 45 ++++++++++--------- .../rules/rewrite/InferPredicatesTest.java | 30 +++++++++++++ .../infer_predicate/infer_predicate.groovy | 18 ++++++++ 3 files changed, 71 insertions(+), 22 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index dde0ee5c8d91f0..71818966696958 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -59,12 +59,12 @@ public Set infer(Set predicates) { } /** - * Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression` + * Use the left or right child of `equalExpr` to replace the left or right child of `expression` * Now only support infer `ComparisonPredicate`. * TODO: We should determine whether `expression` satisfies the condition for replacement * eg: Satisfy `expression` is non-deterministic */ - private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) { + private Expression doInfer(Expression equalExpr, Expression expression) { return expression.accept(new DefaultExpressionRewriter() { @Override @@ -76,42 +76,43 @@ public Expression visit(Expression expr, Void context) { public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { // we need to get expression covered by cast, because we want to infer different datatype if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) { - return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left())); + return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()), equalExpr); } else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) { - return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right())); + return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()), equalExpr); } return super.visit(cp, context); } private boolean isDataTypeValid(DataType originDataType, Expression expr) { - if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType) - && (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType) + if ((expr.child(0).getDataType() instanceof IntegralType) + && (expr.child(1).getDataType() instanceof IntegralType) && (originDataType instanceof IntegralType)) { // infer filter can not be lower than original datatype, or dataset would be wrong if (!((IntegralType) originDataType).widerThan( - (IntegralType) leftSlotEqualToRightSlot.child(0).getDataType()) + (IntegralType) expr.child(0).getDataType()) && !((IntegralType) originDataType).widerThan( - (IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) { + (IntegralType) expr.child(1).getDataType())) { return true; } + } else if (expr.child(0).getDataType().equals(expr.child(1).getDataType())) { + return true; } return false; } - private Expression replaceSlot(Expression expr, DataType originDataType) { - return expr.rewriteUp(e -> { - if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) { - if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) { - if (!leftSlotEqualToRightSlot.child(1).getDataType().equals(e.getDataType())) { - return new Cast(leftSlotEqualToRightSlot.child(1), e.getDataType()); - } - return leftSlotEqualToRightSlot.child(1); - } else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) { - if (!leftSlotEqualToRightSlot.child(0).getDataType().equals(e.getDataType())) { - return new Cast(leftSlotEqualToRightSlot.child(0), e.getDataType()); - } - return leftSlotEqualToRightSlot.child(0); - } + private Expression replaceSlot(Expression sourcePredicate, DataType originDataType, Expression equal) { + if (!isDataTypeValid(originDataType, equal)) { + return sourcePredicate; + } + return sourcePredicate.rewriteUp(e -> { + // we can not replace Cast expression to slot because when rewrite up, we have replace child of cast + if (e instanceof Cast) { + return e; + } + if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(0))) { + return equal.child(1); + } else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(1))) { + return equal.child(0); } return e; }); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index adc67ca835f915..b7b235d2b43041 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -17,15 +17,33 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; import org.apache.doris.utframe.TestWithFeService; +import com.google.common.collect.Sets; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.util.Optional; +import java.util.Set; + public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported { + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + + private final PredicatePropagation propagation = new PredicatePropagation(); + @Override protected void runBeforeAll() throws Exception { createDatabase("test"); @@ -628,4 +646,16 @@ public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() { ).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN) ); } + + @Test + void testInfer() { + EqualTo equalTo = new EqualTo(new Cast(scan1.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1)); + EqualTo equalTo2 = new EqualTo(scan2.getOutput().get(0), scan1.getOutput().get(0)); + Set predicates = Sets.newHashSet(); + predicates.add(equalTo2); + predicates.add(equalTo); + Set newPredicates = propagation.infer(predicates); + Optional newPredicate = newPredicates.stream().findFirst(); + Assertions.assertTrue(newPredicate.get().equals(new EqualTo(new Cast(scan2.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1)))); + } } diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy index a1621f1c239aa5..120c9a8f674458 100644 --- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -22,6 +22,8 @@ suite("test_infer_predicate") { sql 'drop table if exists infer_tb1;' sql 'drop table if exists infer_tb2;' sql 'drop table if exists infer_tb3;' + sql 'drop table if exists infer_tb4;' + sql 'drop table if exists infer_tb5;' sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' @@ -29,6 +31,10 @@ suite("test_infer_predicate") { sql '''create table infer_tb3 (k1 varchar(100), k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' + sql '''create table infer_tb4 (k1 varchar(100), k2 date) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' + + sql '''create table infer_tb5 (k1 varchar(100), k3 date) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' + explain { sql "select * from infer_tb1 inner join infer_tb2 where infer_tb2.k1 = infer_tb1.k2 and infer_tb2.k1 = 1;" contains "PREDICATES: k2" @@ -55,4 +61,16 @@ suite("test_infer_predicate") { contains "PREDICATES: k3" contains "PREDICATES: k2" } + + explain { + sql "select * from infer_tb4 left join infer_tb5 on infer_tb4.k2 = infer_tb5.k3 where infer_tb4.k2 = '20230901';" + contains "PREDICATES: k3" + contains "PREDICATES: k2" + } + + sql 'drop table if exists infer_tb1;' + sql 'drop table if exists infer_tb2;' + sql 'drop table if exists infer_tb3;' + sql 'drop table if exists infer_tb4;' + sql 'drop table if exists infer_tb5;' }