From 361b5b610bd768829c23954e658266951c03790b Mon Sep 17 00:00:00 2001 From: LiBinfeng <1204975323@qq.com> Date: Mon, 26 Jun 2023 14:16:18 +0800 Subject: [PATCH 1/2] [Fix](Nereids) Add cast comparison with slot reference when inferring predicate --- .../rules/rewrite/PredicatePropagation.java | 29 +++++++++++--- .../infer_predicate/infer_predicate.groovy | 38 +++++++++++++++++++ .../nereids_p0/infer_predicate/load.groovy | 22 +++++++++++ 3 files changed, 84 insertions(+), 5 deletions(-) create mode 100644 regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy create mode 100644 regression-test/suites/nereids_p0/infer_predicate/load.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index 71ceee713aeffa..20519988e38799 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; @@ -70,19 +71,36 @@ public Expression visit(Expression expr, Void context) { @Override public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { - if (cp.left().isSlot() && cp.right().isConstant()) { + if ((cp.left().isSlot() || (cp.left() instanceof Cast && cp.left().child(0).isSlot())) + && (cp.right().isConstant())) { return replaceSlot(cp); - } else if (cp.left().isConstant() && cp.right().isSlot()) { + } else if ((cp.right().isSlot() || (cp.right() instanceof Cast && cp.right().child(0).isSlot())) + && cp.left().isConstant()) { return replaceSlot(cp); } return super.visit(cp, context); } + private boolean isTwoExpressionEqualWithCast(Expression left, Expression right) { + if (left.getDataType() != right.getDataType()) { + return false; + } + if (left instanceof Cast && right instanceof Cast) { + return ((Cast) left).child().equals(((Cast) right).child()); + } else if (left instanceof Cast) { + return ((Cast) left).child().equals(right); + } else if (right instanceof Cast) { + return ((Cast) right).child().equals(left); + } else { + return left.equals(right); + } + } + private Expression replaceSlot(Expression expr) { return expr.rewriteUp(e -> { - if (e.equals(leftSlotEqualToRightSlot.child(0))) { + if (isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) { return leftSlotEqualToRightSlot.child(1); - } else if (e.equals(leftSlotEqualToRightSlot.child(1))) { + } else if (isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) { return leftSlotEqualToRightSlot.child(0); } else { return e; @@ -98,7 +116,8 @@ private Expression replaceSlot(Expression expr) { */ private boolean canEquivalentInfer(Expression predicate) { return predicate instanceof EqualTo - && predicate.children().stream().allMatch(e -> e instanceof SlotReference) + && predicate.children().stream().allMatch(e -> + (e instanceof SlotReference) || (e instanceof Cast && e.child(0).isSlot())) && predicate.child(0).getDataType().equals(predicate.child(1).getDataType()); } diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy new file mode 100644 index 00000000000000..691cc912631212 --- /dev/null +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_infer_predicate") { + sql 'set enable_nereids_planner=true' + sql 'set enable_fallback_to_original_planner=false' + + sql 'drop table if exists infer_tb1;' + sql 'drop table if exists infer_tb2;' + + sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' + + sql '''create table infer_tb2 (k1 tinyint, k2 smallint, k3 int, k4 bigint, k5 largeint, k6 date, k7 datetime, k8 float, k9 double) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' + + res = sql "explain select * from infer_tb1 inner join infer_tb2 where infer_tb2.k1 = infer_tb1.k2 and infer_tb2.k1 = 1;" + assertFalse(res.contains("k2 = 1")) + + res = sql "explain select * from infer_tb1 inner join infer_tb2 where infer_tb2.k2 = infer_tb1.k1 and infer_tb2.k1 = 1;" + assertFalse(res.contains("k2 = 1")) + + res = sql "explain select * from infer_tb1 inner join infer_tb2 where infer_tb1.k2 = infer_tb2.k3 and infer_tb2.k3 = 1;" + assertFalse(res.contains("k2 = 1")) + +} diff --git a/regression-test/suites/nereids_p0/infer_predicate/load.groovy b/regression-test/suites/nereids_p0/infer_predicate/load.groovy new file mode 100644 index 00000000000000..829395215c3263 --- /dev/null +++ b/regression-test/suites/nereids_p0/infer_predicate/load.groovy @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("load") { + sql 'create database if not exists nereids_infer_predicate_test' + + sql 'use nereids_infer_predicate_test' +} From b5c908949032fd11804d2f5c0243d011220dc889 Mon Sep 17 00:00:00 2001 From: LiBinfeng <1204975323@qq.com> Date: Wed, 5 Jul 2023 14:18:39 +0800 Subject: [PATCH 2/2] [Fix](Nereids) change cast one level to multi levels --- .../rules/rewrite/PredicatePropagation.java | 47 ++++++++++--------- .../nereids/types/coercion/IntegralType.java | 4 ++ .../doris/nereids/util/ExpressionUtils.java | 29 ++++++++++++ .../infer_predicate/infer_predicate.groovy | 23 +++++---- .../nereids_p0/infer_predicate/load.groovy | 22 --------- 5 files changed, 73 insertions(+), 52 deletions(-) delete mode 100644 regression-test/suites/nereids_p0/infer_predicate/load.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index 20519988e38799..9602bb4a5653ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -23,6 +23,9 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.coercion.IntegralType; +import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.Sets; @@ -71,36 +74,38 @@ public Expression visit(Expression expr, Void context) { @Override public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { - if ((cp.left().isSlot() || (cp.left() instanceof Cast && cp.left().child(0).isSlot())) - && (cp.right().isConstant())) { - return replaceSlot(cp); - } else if ((cp.right().isSlot() || (cp.right() instanceof Cast && cp.right().child(0).isSlot())) - && cp.left().isConstant()) { - return replaceSlot(cp); + // we need to get expression covered by cast, because we want to infer different datatype + if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) { + return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left())); + } else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) { + return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right())); } return super.visit(cp, context); } - private boolean isTwoExpressionEqualWithCast(Expression left, Expression right) { - if (left.getDataType() != right.getDataType()) { - return false; - } - if (left instanceof Cast && right instanceof Cast) { - return ((Cast) left).child().equals(((Cast) right).child()); - } else if (left instanceof Cast) { - return ((Cast) left).child().equals(right); - } else if (right instanceof Cast) { - return ((Cast) right).child().equals(left); - } else { - return left.equals(right); + private boolean isOriginDataTypeBigger(DataType originDataType, Expression expr) { + if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType) + && (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType) + && (originDataType instanceof IntegralType)) { + // infer filter can not be lower than original datatype, or dataset would be wrong + if (((IntegralType) originDataType).widerThan( + (IntegralType) leftSlotEqualToRightSlot.child(0).getDataType()) + || ((IntegralType) originDataType).widerThan( + (IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) { + return true; + } } + return false; } - private Expression replaceSlot(Expression expr) { + private Expression replaceSlot(Expression expr, DataType originDataType) { return expr.rewriteUp(e -> { - if (isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) { + if (isOriginDataTypeBigger(originDataType, leftSlotEqualToRightSlot)) { + return e; + } + if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) { return leftSlotEqualToRightSlot.child(1); - } else if (isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) { + } else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) { return leftSlotEqualToRightSlot.child(0); } else { return e; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java index 7c147ff0173d27..542f9df993487f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java @@ -41,4 +41,8 @@ public boolean acceptsType(AbstractDataType other) { public String simpleString() { return "integral"; } + + public boolean widerThan(IntegralType other) { + return this.width() > other.width(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index a3a3ca1b80a812..71f9808ad24d75 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; import com.google.common.base.Predicate; @@ -251,6 +252,34 @@ public static Optional extractSlotOrCastOnSlot(Expression expr) { } } + /** + * get slot covered by cast + * example: input: cast(cast(table.columnA)) output: columnA.datatype + * + */ + public static DataType getDatatypeCoveredByCast(Expression expr) { + if (expr instanceof Cast) { + return getDatatypeCoveredByCast(((Cast) expr).child()); + } + return expr.getDataType(); + } + + /** + * judge if expression is slot covered by cast + * example: cast(cast(table.columnA)) + */ + public static boolean isExpressionSlotCoveredByCast(Expression expr) { + if (expr instanceof Cast) { + return isExpressionSlotCoveredByCast(((Cast) expr).child()); + } + return expr instanceof SlotReference; + } + + public static boolean isTwoExpressionEqualWithCast(Expression left, Expression right) { + return ExpressionUtils.extractSlotOrCastOnSlot(left) + .equals(ExpressionUtils.extractSlotOrCastOnSlot(right)); + } + /** * Replace expression node in the expression tree by `replaceMap` in top-down manner. * For example. diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy index 691cc912631212..ac462011859480 100644 --- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -26,13 +26,18 @@ suite("test_infer_predicate") { sql '''create table infer_tb2 (k1 tinyint, k2 smallint, k3 int, k4 bigint, k5 largeint, k6 date, k7 datetime, k8 float, k9 double) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' - res = sql "explain select * from infer_tb1 inner join infer_tb2 where infer_tb2.k1 = infer_tb1.k2 and infer_tb2.k1 = 1;" - assertFalse(res.contains("k2 = 1")) - - res = sql "explain select * from infer_tb1 inner join infer_tb2 where infer_tb2.k2 = infer_tb1.k1 and infer_tb2.k1 = 1;" - assertFalse(res.contains("k2 = 1")) - - res = sql "explain select * from infer_tb1 inner join infer_tb2 where infer_tb1.k2 = infer_tb2.k3 and infer_tb2.k3 = 1;" - assertFalse(res.contains("k2 = 1")) - + explain { + sql "select * from infer_tb1 inner join infer_tb2 where infer_tb2.k1 = infer_tb1.k2 and infer_tb2.k1 = 1;" + contains "PREDICATES: k2[#20] = 1" + } + + explain { + sql "select * from infer_tb1 inner join infer_tb2 where infer_tb1.k2 = infer_tb2.k1 and infer_tb2.k1 = 1;" + contains "PREDICATES: k2[#20] = 1" + } + + explain { + sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;" + notContains "PREDICATES: k2[#20] = 1" + } } diff --git a/regression-test/suites/nereids_p0/infer_predicate/load.groovy b/regression-test/suites/nereids_p0/infer_predicate/load.groovy deleted file mode 100644 index 829395215c3263..00000000000000 --- a/regression-test/suites/nereids_p0/infer_predicate/load.groovy +++ /dev/null @@ -1,22 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -suite("load") { - sql 'create database if not exists nereids_infer_predicate_test' - - sql 'use nereids_infer_predicate_test' -}