From 327f48298344f15036635d77ef88d84672b197f6 Mon Sep 17 00:00:00 2001 From: lichi <12095047@qq.com> Date: Wed, 19 Jun 2024 14:23:55 +0800 Subject: [PATCH] [fix](nereids)NullSafeEqualToEqual rule should keep <=> unchanged if it has none-literal child --- .../expression/ExpressionOptimization.java | 2 + .../rules/NullSafeEqualToEqual.java | 24 +++++------- .../rules/NullSafeEqualToEqualTest.java | 38 +++++++++++++++---- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index 828592bbba3a5a..abf57057601dc8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule; import org.apache.doris.nereids.rules.expression.rules.LikeToEqualRewrite; +import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual; import org.apache.doris.nereids.rules.expression.rules.OrToIn; import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison; @@ -51,6 +52,7 @@ public class ExpressionOptimization extends ExpressionRewrite { ArrayContainToArrayOverlap.INSTANCE, CaseWhenToIf.INSTANCE, TopnToMax.INSTANCE, + NullSafeEqualToEqual.INSTANCE, LikeToEqualRewrite.INSTANCE ) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java index dda109a42e083a..16c4663a1edacd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java @@ -24,17 +24,16 @@ import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; -import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import com.google.common.collect.ImmutableList; import java.util.List; /** - * convert "<=>" to "=", if both sides are not nullable * convert "A <=> null" to "A is null" * null <=> null : true * null <=> 1 : false + * 1 <=> 2 : 1 = 2 */ public class NullSafeEqualToEqual implements ExpressionPatternRuleFactory { public static final NullSafeEqualToEqual INSTANCE = new NullSafeEqualToEqual(); @@ -47,19 +46,14 @@ public List> buildRules() { } private static Expression rewrite(NullSafeEqual nullSafeEqual) { - if (nullSafeEqual.left() instanceof NullLiteral) { - if (nullSafeEqual.right().nullable()) { - return new IsNull(nullSafeEqual.right()); - } else { - return BooleanLiteral.FALSE; - } - } else if (nullSafeEqual.right() instanceof NullLiteral) { - if (nullSafeEqual.left().nullable()) { - return new IsNull(nullSafeEqual.left()); - } else { - return BooleanLiteral.FALSE; - } - } else if (!nullSafeEqual.left().nullable() && !nullSafeEqual.right().nullable()) { + // because the nullable info hasn't been finalized yet, the optimization is limited + if (nullSafeEqual.left().isNullLiteral() && nullSafeEqual.right().isNullLiteral()) { + return BooleanLiteral.TRUE; + } else if (nullSafeEqual.left().isNullLiteral()) { + return nullSafeEqual.right().isLiteral() ? BooleanLiteral.FALSE : new IsNull(nullSafeEqual.right()); + } else if (nullSafeEqual.right().isNullLiteral()) { + return nullSafeEqual.left().isLiteral() ? BooleanLiteral.FALSE : new IsNull(nullSafeEqual.left()); + } else if (nullSafeEqual.left().isLiteral() && nullSafeEqual.right().isLiteral()) { return new EqualTo(nullSafeEqual.left(), nullSafeEqual.right()); } return nullSafeEqual; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java index db1186738da713..8da25e92e7eec7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.types.StringType; @@ -32,7 +33,7 @@ class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper { - // "A<=> Null" to "A is null" + // "A <=> Null" to "A is null" @Test void testNullSafeEqualToIsNull() { executor = new ExpressionRuleExecutor(ImmutableList.of( @@ -40,21 +41,31 @@ void testNullSafeEqualToIsNull() { )); SlotReference slot = new SlotReference("a", StringType.INSTANCE, true); assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), new IsNull(slot)); + slot = new SlotReference("a", StringType.INSTANCE, false); + assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), new IsNull(slot)); } - // "A<=> Null" to "False", when A is not nullable + // "0 <=> Null" to false @Test void testNullSafeEqualToFalse() { executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(NullSafeEqualToEqual.INSTANCE) )); - SlotReference slot = new SlotReference("a", StringType.INSTANCE, false); - assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), BooleanLiteral.FALSE); + assertRewrite(new NullSafeEqual(new IntegerLiteral(0), NullLiteral.INSTANCE), BooleanLiteral.FALSE); + } + + // "NULL <=> Null" to false + @Test + void testNullSafeEqualToTrue() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); + assertRewrite(new NullSafeEqual(NullLiteral.INSTANCE, NullLiteral.INSTANCE), BooleanLiteral.TRUE); } // "A(nullable)<=>B" not changed @Test - void testNullSafeEqualNotChangedLeft() { + void testNullSafeEqualNotChangedLeftNullable() { executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(NullSafeEqualToEqual.INSTANCE) )); @@ -65,7 +76,7 @@ void testNullSafeEqualNotChangedLeft() { // "A<=>B(nullable)" not changed @Test - void testNullSafeEqualNotChangedRight() { + void testNullSafeEqualNotChangedRightNullable() { executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(NullSafeEqualToEqual.INSTANCE) )); @@ -74,14 +85,25 @@ void testNullSafeEqualNotChangedRight() { assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b)); } - // "A<=>B" changed + // "A<=>B" not changed @Test - void testNullSafeEqualToEqual() { + void testNullSafeEqualNotChangedBothNullable() { executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(NullSafeEqualToEqual.INSTANCE) )); SlotReference a = new SlotReference("a", StringType.INSTANCE, false); SlotReference b = new SlotReference("b", StringType.INSTANCE, false); + assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b)); + } + + // "1 <=> 0" to "1 = 0" + @Test + void testNullSafeEqualToEqual() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); + IntegerLiteral a = new IntegerLiteral(0); + IntegerLiteral b = new IntegerLiteral(1); assertRewrite(new NullSafeEqual(a, b), new EqualTo(a, b)); } }