From d986b9bfc5652fbfb76396c2e2ed1ec1a3faefea Mon Sep 17 00:00:00 2001 From: LiBinfeng Date: Mon, 18 Nov 2024 11:45:16 +0800 Subject: [PATCH] [fix](Nereids) fold const return type does not matched with type coercion (#44022) Related PR: #40744 when executing floor(1) it would castTo decimalV3(3,0) because it need (3,0) to contain it's message. But after fold const, it lost precision(3) because decimalV3 literal class does not have mechanism to save precision Solved: after folding constant, we need to change result type to the type we wanted --- .../executable/NumericArithmetic.java | 21 ++++++++++++------- .../fold_constant_numeric_arithmatic.groovy | 9 ++++++++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java index a9acfeb2d6095b..d739c830df2a17 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java @@ -704,7 +704,7 @@ private static Expression castDecimalV3Literal(DecimalV3Literal literal, int pre */ @ExecFunction(name = "round") public static Expression round(DecimalV3Literal first) { - return castDecimalV3Literal(first.round(0), first.getValue().precision()); + return castDecimalV3Literal(first.round(0), ((DecimalV3Type) first.getDataType()).getPrecision()); } /** @@ -712,7 +712,8 @@ public static Expression round(DecimalV3Literal first) { */ @ExecFunction(name = "round") public static Expression round(DecimalV3Literal first, IntegerLiteral second) { - return castDecimalV3Literal(first.round(second.getValue()), first.getValue().precision()); + return castDecimalV3Literal(first.round(second.getValue()), + ((DecimalV3Type) first.getDataType()).getPrecision()); } /** @@ -738,7 +739,7 @@ public static Expression round(DoubleLiteral first, IntegerLiteral second) { */ @ExecFunction(name = "ceil") public static Expression ceil(DecimalV3Literal first) { - return castDecimalV3Literal(first.roundCeiling(0), first.getValue().precision()); + return castDecimalV3Literal(first.roundCeiling(0), ((DecimalV3Type) first.getDataType()).getPrecision()); } /** @@ -746,7 +747,8 @@ public static Expression ceil(DecimalV3Literal first) { */ @ExecFunction(name = "ceil") public static Expression ceil(DecimalV3Literal first, IntegerLiteral second) { - return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision()); + return castDecimalV3Literal(first.roundCeiling(second.getValue()), + ((DecimalV3Type) first.getDataType()).getPrecision()); } /** @@ -772,7 +774,7 @@ public static Expression ceil(DoubleLiteral first, IntegerLiteral second) { */ @ExecFunction(name = "floor") public static Expression floor(DecimalV3Literal first) { - return castDecimalV3Literal(first.roundFloor(0), first.getValue().precision()); + return castDecimalV3Literal(first.roundFloor(0), ((DecimalV3Type) first.getDataType()).getPrecision()); } /** @@ -780,7 +782,8 @@ public static Expression floor(DecimalV3Literal first) { */ @ExecFunction(name = "floor") public static Expression floor(DecimalV3Literal first, IntegerLiteral second) { - return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision()); + return castDecimalV3Literal(first.roundFloor(second.getValue()), + ((DecimalV3Type) first.getDataType()).getPrecision()); } /** @@ -1142,9 +1145,11 @@ public static Expression truncate(DecimalV3Literal first, IntegerLiteral second) if (first.getValue().compareTo(BigDecimal.ZERO) == 0) { return first; } else if (first.getValue().compareTo(BigDecimal.ZERO) < 0) { - return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision()); + return castDecimalV3Literal(first.roundCeiling(second.getValue()), + ((DecimalV3Type) first.getDataType()).getPrecision()); } else { - return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision()); + return castDecimalV3Literal(first.roundFloor(second.getValue()), + ((DecimalV3Type) first.getDataType()).getPrecision()); } } diff --git a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy index dbfd3fad7bf913..14fecc91b35317 100644 --- a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy +++ b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy @@ -434,4 +434,13 @@ test { testFoldConst("with cte as (select round(300.343, -4) order by 1 limit 1) select * from cte") testFoldConst("with cte as (select ceil(300.343, -4) order by 1 limit 1) select * from cte") testFoldConst("with cte as (select truncate(300.343, -4) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(3) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(3) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(3) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(3, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(3, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(3, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(3, 2) order by 1 limit 1) select * from cte") }