From 7c9a528232c2c32963ee513fe0778857c3772c54 Mon Sep 17 00:00:00 2001 From: wangqingtao6 Date: Fri, 24 May 2024 13:38:57 +0800 Subject: [PATCH] string literal coercion of in predicate --- .../doris/nereids/util/TypeCoercionUtils.java | 30 ++++++++++++++++--- .../nereids/util/TypeCoercionUtilsTest.java | 26 ++++++++++++++++ 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index b79c4567038e82..24da439aa27f59 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -104,7 +104,9 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.util.ArrayList; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.Optional; import java.util.function.Supplier; @@ -727,20 +729,40 @@ public static Expression processInPredicate(InPredicate inPredicate) { .allMatch(dt -> dt.equals(inPredicate.getCompareExpr().getDataType()))) { return inPredicate; } + // process string literal with numeric + boolean hitString = false; + List newOptions = new ArrayList<>(inPredicate.getOptions()); + if (!(inPredicate.getCompareExpr().getDataType().isStringLikeType())) { + ListIterator iterator = newOptions.listIterator(); + while (iterator.hasNext()) { + Expression origOption = iterator.next(); + if (origOption instanceof Literal && ((Literal) origOption).isStringLikeLiteral()) { + Optional option = TypeCoercionUtils.characterLiteralTypeCoercion( + ((Literal) origOption).getStringValue(), inPredicate.getCompareExpr().getDataType()); + if (option.isPresent()) { + iterator.set(option.get()); + hitString = true; + } + } + } + } + final InPredicate fmtInPredicate = + hitString ? new InPredicate(inPredicate.getCompareExpr(), newOptions) : inPredicate; + Optional optionalCommonType = TypeCoercionUtils.findWiderCommonTypeForComparison( - inPredicate.children() + fmtInPredicate.children() .stream() .map(Expression::getDataType).collect(Collectors.toList()), true); return optionalCommonType .map(commonType -> { - List newChildren = inPredicate.children().stream() + List newChildren = fmtInPredicate.children().stream() .map(e -> TypeCoercionUtils.castIfNotSameType(e, commonType)) .collect(Collectors.toList()); - return inPredicate.withChildren(newChildren); + return fmtInPredicate.withChildren(newChildren); }) - .orElse(inPredicate); + .orElse(fmtInPredicate); } /** diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java index 2b82db80430f6a..83ab44a5429f1a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java @@ -21,7 +21,9 @@ import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.literal.CharLiteral; import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; @@ -56,6 +58,7 @@ import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.types.coercion.IntegralType; +import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -734,4 +737,27 @@ public void testDecimalArithmetic() { Assertions.assertEquals(expression.child(0), new Cast(multiply.child(0), DecimalV3Type.createDecimalV3Type(10, 3))); } + + @Test + public void testProcessInStringCoercion() { + // BigInt slot vs String literal + InPredicate bigintString = new InPredicate( + new SlotReference("c1", BigIntType.INSTANCE), + ImmutableList.of( + new VarcharLiteral("200"), + new VarcharLiteral("922337203685477001"))); + bigintString = (InPredicate) TypeCoercionUtils.processInPredicate(bigintString); + Assertions.assertEquals(BigIntType.INSTANCE, bigintString.getCompareExpr().getDataType()); + Assertions.assertEquals(BigIntType.INSTANCE, bigintString.getOptions().get(0).getDataType()); + + // SmallInt slot vs String literal + InPredicate smallIntString = new InPredicate( + new SlotReference("c1", SmallIntType.INSTANCE), + ImmutableList.of( + new DecimalLiteral(new BigDecimal("987654.321")), + new VarcharLiteral("922337203685477001"))); + smallIntString = (InPredicate) TypeCoercionUtils.processInPredicate(smallIntString); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(23, 3), smallIntString.getCompareExpr().getDataType()); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(23, 3), smallIntString.getOptions().get(0).getDataType()); + } }