From aecdf6913b2020881b3fba5af6a618992e27355c Mon Sep 17 00:00:00 2001 From: LiBinfeng Date: Tue, 12 Nov 2024 20:18:51 +0800 Subject: [PATCH] [Fix](Nereids) fix floor/round/ceil/truncate functions type compute precision problem (#43422) - Problem function like ```select floor(300.343, 2)``` precision should be 5 and scale should be 2, but now is (6, 2) after compute precision, but after folding const on fe, it changed to (5, 2) but upper level of plan still expect the output of child to be (6, 2). So it would rise an exception when executing. - How it was fixed fix folding constant precision of floor/round/ceil/truncate functions from (5, 2) to (6, 2) in upper case - Notion when second value is negative and it absolute value >= precision - value, it can not be expressed in fe which result is zero with decimal type (3, 0). like 000. So just let it go back and no using folding constant by fe. - Related PR: #40744 - Release note Fix floor/round/ceil functions precision problem in folding constant --- .../executable/NumericArithmetic.java | 34 ++++++++----------- .../fold_constant_numeric_arithmatic.groovy | 25 ++++++++++++++ 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java index 325e676fc046a0..a9acfeb2d6095b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java @@ -694,12 +694,17 @@ private static Expression checkOutputBoundary(Literal input) { return input; } + private static Expression castDecimalV3Literal(DecimalV3Literal literal, int precision) { + return new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(precision, literal.getValue().scale()), + literal.getValue()); + } + /** * round */ @ExecFunction(name = "round") public static Expression round(DecimalV3Literal first) { - return first.round(0); + return castDecimalV3Literal(first.round(0), first.getValue().precision()); } /** @@ -707,7 +712,7 @@ public static Expression round(DecimalV3Literal first) { */ @ExecFunction(name = "round") public static Expression round(DecimalV3Literal first, IntegerLiteral second) { - return first.round(second.getValue()); + return castDecimalV3Literal(first.round(second.getValue()), first.getValue().precision()); } /** @@ -733,7 +738,7 @@ public static Expression round(DoubleLiteral first, IntegerLiteral second) { */ @ExecFunction(name = "ceil") public static Expression ceil(DecimalV3Literal first) { - return first.roundCeiling(0); + return castDecimalV3Literal(first.roundCeiling(0), first.getValue().precision()); } /** @@ -741,7 +746,7 @@ public static Expression ceil(DecimalV3Literal first) { */ @ExecFunction(name = "ceil") public static Expression ceil(DecimalV3Literal first, IntegerLiteral second) { - return first.roundCeiling(second.getValue()); + return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision()); } /** @@ -767,7 +772,7 @@ public static Expression ceil(DoubleLiteral first, IntegerLiteral second) { */ @ExecFunction(name = "floor") public static Expression floor(DecimalV3Literal first) { - return first.roundFloor(0); + return castDecimalV3Literal(first.roundFloor(0), first.getValue().precision()); } /** @@ -775,7 +780,7 @@ public static Expression floor(DecimalV3Literal first) { */ @ExecFunction(name = "floor") public static Expression floor(DecimalV3Literal first, IntegerLiteral second) { - return first.roundFloor(second.getValue()); + return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision()); } /** @@ -1136,21 +1141,10 @@ public static Expression mathE() { public static Expression truncate(DecimalV3Literal first, IntegerLiteral second) { if (first.getValue().compareTo(BigDecimal.ZERO) == 0) { return first; + } else if (first.getValue().compareTo(BigDecimal.ZERO) < 0) { + return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision()); } else { - if (first.getValue().scale() < second.getValue()) { - return first; - } - if (second.getValue() < 0) { - double factor = Math.pow(10, Math.abs(second.getValue())); - return new DecimalV3Literal( - DecimalV3Type.createDecimalV3Type(first.getValue().precision(), 0), - BigDecimal.valueOf(Math.floor(first.getDouble() / factor) * factor)); - } - if (first.getValue().compareTo(BigDecimal.ZERO) == -1) { - return first.roundCeiling(second.getValue()); - } else { - return first.roundFloor(second.getValue()); - } + return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision()); } } diff --git a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy index 5f728651267e82..dbfd3fad7bf913 100644 --- a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy +++ b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy @@ -409,4 +409,29 @@ test { //Additional cases for Xor, Conv, and other mathematical functions testFoldConst("SELECT CONV(-10, 10, 2) AS conv_invalid_base") //Conv with negative input (may be undefined) + + // fix floor/ceil/round function return type with DecimalV3 input + testFoldConst("with cte as (select floor(300.343) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, 2) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, 0) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, 0) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, 0) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, 0) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, -1) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, -1) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, -1) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, -1) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, -4) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, -4) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, -4) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, -4) order by 1 limit 1) select * from cte") }