From 1f27f064e77767618763a7812a64bde4c5bc94bf Mon Sep 17 00:00:00 2001 From: wangqingtao6 Date: Fri, 24 May 2024 13:59:24 +0800 Subject: [PATCH] string literal coercion of in predicate --- .../doris/nereids/util/TypeCoercionUtils.java | 34 +++++++++++++++---- .../nereids/util/TypeCoercionUtilsTest.java | 23 +++++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 214fbc180497fb..e6dcca83a8bc6f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -118,7 +118,9 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.util.ArrayList; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.Optional; import java.util.function.Function; @@ -975,8 +977,28 @@ public static Expression processInPredicate(InPredicate inPredicate) { } return inPredicate; } + // process string literal with numeric + boolean hitString = false; + List newOptions = new ArrayList<>(inPredicate.getOptions()); + if (!(inPredicate.getCompareExpr().getDataType().isStringLikeType())) { + ListIterator iterator = newOptions.listIterator(); + while (iterator.hasNext()) { + Expression origOption = iterator.next(); + if (origOption instanceof Literal && ((Literal) origOption).isStringLikeLiteral()) { + Optional option = TypeCoercionUtils.characterLiteralTypeCoercion( + ((Literal) origOption).getStringValue(), inPredicate.getCompareExpr().getDataType()); + if (option.isPresent()) { + iterator.set(option.get()); + hitString = true; + } + } + } + } + final InPredicate fmtInPredicate = + hitString ? new InPredicate(inPredicate.getCompareExpr(), newOptions) : inPredicate; + Optional optionalCommonType = TypeCoercionUtils.findWiderCommonTypeForComparison( - inPredicate.children() + fmtInPredicate.children() .stream() .map(Expression::getDataType).collect(Collectors.toList()), true); @@ -999,18 +1021,18 @@ public static Expression processInPredicate(InPredicate inPredicate) { if (optionalCommonType.isPresent()) { optionalCommonType = Optional.of(downgradeDecimalAndDateLikeType( optionalCommonType.get(), - inPredicate.getCompareExpr(), - inPredicate.getOptions().toArray(new Expression[0]))); + fmtInPredicate.getCompareExpr(), + fmtInPredicate.getOptions().toArray(new Expression[0]))); } return optionalCommonType .map(commonType -> { - List newChildren = inPredicate.children().stream() + List newChildren = fmtInPredicate.children().stream() .map(e -> TypeCoercionUtils.castIfNotSameType(e, commonType)) .collect(Collectors.toList()); - return inPredicate.withChildren(newChildren); + return fmtInPredicate.withChildren(newChildren); }) - .orElse(inPredicate); + .orElse(fmtInPredicate); } /** diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java index be465d9c371e4a..8d32dfbb2a270a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java @@ -802,4 +802,27 @@ public void testProcessComparisonPredicateDowngrade() { datetimeDowngrade = (EqualTo) TypeCoercionUtils.processComparisonPredicate(datetimeDowngrade); Assertions.assertEquals(DateTimeType.INSTANCE, datetimeDowngrade.left().getDataType()); } + + @Test + public void testProcessInStringCoercion() { + // BigInt slot vs String literal + InPredicate bigintString = new InPredicate( + new SlotReference("c1", BigIntType.INSTANCE), + ImmutableList.of( + new VarcharLiteral("200"), + new VarcharLiteral("922337203685477001"))); + bigintString = (InPredicate) TypeCoercionUtils.processInPredicate(bigintString); + Assertions.assertEquals(BigIntType.INSTANCE, bigintString.getCompareExpr().getDataType()); + Assertions.assertEquals(BigIntType.INSTANCE, bigintString.getOptions().get(0).getDataType()); + + // SmallInt slot vs String literal + InPredicate smallIntString = new InPredicate( + new SlotReference("c1", SmallIntType.INSTANCE), + ImmutableList.of( + new DecimalLiteral(new BigDecimal("987654.321")), + new VarcharLiteral("922337203685477001"))); + smallIntString = (InPredicate) TypeCoercionUtils.processInPredicate(smallIntString); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(23, 3), smallIntString.getCompareExpr().getDataType()); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(23, 3), smallIntString.getOptions().get(0).getDataType()); + } }