diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index 71ceee713aeffa..9602bb4a5653ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -17,11 +17,15 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.coercion.IntegralType; +import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.Sets; @@ -70,19 +74,38 @@ public Expression visit(Expression expr, Void context) { @Override public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { - if (cp.left().isSlot() && cp.right().isConstant()) { - return replaceSlot(cp); - } else if (cp.left().isConstant() && cp.right().isSlot()) { - return replaceSlot(cp); + // we need to get expression covered by cast, because we want to infer different datatype + if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) { + return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left())); + } else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) { + return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right())); } return super.visit(cp, context); } - private Expression replaceSlot(Expression expr) { + private boolean isOriginDataTypeBigger(DataType originDataType, Expression expr) { + if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType) + && (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType) + && (originDataType instanceof IntegralType)) { + // infer filter can not be lower than original datatype, or dataset would be wrong + if (((IntegralType) originDataType).widerThan( + (IntegralType) leftSlotEqualToRightSlot.child(0).getDataType()) + || ((IntegralType) originDataType).widerThan( + (IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) { + return true; + } + } + return false; + } + + private Expression replaceSlot(Expression expr, DataType originDataType) { return expr.rewriteUp(e -> { - if (e.equals(leftSlotEqualToRightSlot.child(0))) { + if (isOriginDataTypeBigger(originDataType, leftSlotEqualToRightSlot)) { + return e; + } + if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) { return leftSlotEqualToRightSlot.child(1); - } else if (e.equals(leftSlotEqualToRightSlot.child(1))) { + } else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) { return leftSlotEqualToRightSlot.child(0); } else { return e; @@ -98,7 +121,8 @@ private Expression replaceSlot(Expression expr) { */ private boolean canEquivalentInfer(Expression predicate) { return predicate instanceof EqualTo - && predicate.children().stream().allMatch(e -> e instanceof SlotReference) + && predicate.children().stream().allMatch(e -> + (e instanceof SlotReference) || (e instanceof Cast && e.child(0).isSlot())) && predicate.child(0).getDataType().equals(predicate.child(1).getDataType()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java index 7c147ff0173d27..542f9df993487f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java @@ -41,4 +41,8 @@ public boolean acceptsType(AbstractDataType other) { public String simpleString() { return "integral"; } + + public boolean widerThan(IntegralType other) { + return this.width() > other.width(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index a3a3ca1b80a812..71f9808ad24d75 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; import com.google.common.base.Predicate; @@ -251,6 +252,34 @@ public static Optional extractSlotOrCastOnSlot(Expression expr) { } } + /** + * get slot covered by cast + * example: input: cast(cast(table.columnA)) output: columnA.datatype + * + */ + public static DataType getDatatypeCoveredByCast(Expression expr) { + if (expr instanceof Cast) { + return getDatatypeCoveredByCast(((Cast) expr).child()); + } + return expr.getDataType(); + } + + /** + * judge if expression is slot covered by cast + * example: cast(cast(table.columnA)) + */ + public static boolean isExpressionSlotCoveredByCast(Expression expr) { + if (expr instanceof Cast) { + return isExpressionSlotCoveredByCast(((Cast) expr).child()); + } + return expr instanceof SlotReference; + } + + public static boolean isTwoExpressionEqualWithCast(Expression left, Expression right) { + return ExpressionUtils.extractSlotOrCastOnSlot(left) + .equals(ExpressionUtils.extractSlotOrCastOnSlot(right)); + } + /** * Replace expression node in the expression tree by `replaceMap` in top-down manner. * For example. diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy new file mode 100644 index 00000000000000..ac462011859480 --- /dev/null +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_infer_predicate") { + sql 'set enable_nereids_planner=true' + sql 'set enable_fallback_to_original_planner=false' + + sql 'drop table if exists infer_tb1;' + sql 'drop table if exists infer_tb2;' + + sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' + + sql '''create table infer_tb2 (k1 tinyint, k2 smallint, k3 int, k4 bigint, k5 largeint, k6 date, k7 datetime, k8 float, k9 double) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' + + explain { + sql "select * from infer_tb1 inner join infer_tb2 where infer_tb2.k1 = infer_tb1.k2 and infer_tb2.k1 = 1;" + contains "PREDICATES: k2[#20] = 1" + } + + explain { + sql "select * from infer_tb1 inner join infer_tb2 where infer_tb1.k2 = infer_tb2.k1 and infer_tb2.k1 = 1;" + contains "PREDICATES: k2[#20] = 1" + } + + explain { + sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;" + notContains "PREDICATES: k2[#20] = 1" + } +}