diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java index 325e676fc046a0..a9acfeb2d6095b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java @@ -694,12 +694,17 @@ private static Expression checkOutputBoundary(Literal input) { return input; } + private static Expression castDecimalV3Literal(DecimalV3Literal literal, int precision) { + return new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(precision, literal.getValue().scale()), + literal.getValue()); + } + /** * round */ @ExecFunction(name = "round") public static Expression round(DecimalV3Literal first) { - return first.round(0); + return castDecimalV3Literal(first.round(0), first.getValue().precision()); } /** @@ -707,7 +712,7 @@ public static Expression round(DecimalV3Literal first) { */ @ExecFunction(name = "round") public static Expression round(DecimalV3Literal first, IntegerLiteral second) { - return first.round(second.getValue()); + return castDecimalV3Literal(first.round(second.getValue()), first.getValue().precision()); } /** @@ -733,7 +738,7 @@ public static Expression round(DoubleLiteral first, IntegerLiteral second) { */ @ExecFunction(name = "ceil") public static Expression ceil(DecimalV3Literal first) { - return first.roundCeiling(0); + return castDecimalV3Literal(first.roundCeiling(0), first.getValue().precision()); } /** @@ -741,7 +746,7 @@ public static Expression ceil(DecimalV3Literal first) { */ @ExecFunction(name = "ceil") public static Expression ceil(DecimalV3Literal first, IntegerLiteral second) { - return first.roundCeiling(second.getValue()); + return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision()); } /** @@ -767,7 +772,7 @@ public static Expression ceil(DoubleLiteral first, IntegerLiteral second) { */ @ExecFunction(name = "floor") public static Expression floor(DecimalV3Literal first) { - return first.roundFloor(0); + return castDecimalV3Literal(first.roundFloor(0), first.getValue().precision()); } /** @@ -775,7 +780,7 @@ public static Expression floor(DecimalV3Literal first) { */ @ExecFunction(name = "floor") public static Expression floor(DecimalV3Literal first, IntegerLiteral second) { - return first.roundFloor(second.getValue()); + return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision()); } /** @@ -1136,21 +1141,10 @@ public static Expression mathE() { public static Expression truncate(DecimalV3Literal first, IntegerLiteral second) { if (first.getValue().compareTo(BigDecimal.ZERO) == 0) { return first; + } else if (first.getValue().compareTo(BigDecimal.ZERO) < 0) { + return castDecimalV3Literal(first.roundCeiling(second.getValue()), first.getValue().precision()); } else { - if (first.getValue().scale() < second.getValue()) { - return first; - } - if (second.getValue() < 0) { - double factor = Math.pow(10, Math.abs(second.getValue())); - return new DecimalV3Literal( - DecimalV3Type.createDecimalV3Type(first.getValue().precision(), 0), - BigDecimal.valueOf(Math.floor(first.getDouble() / factor) * factor)); - } - if (first.getValue().compareTo(BigDecimal.ZERO) == -1) { - return first.roundCeiling(second.getValue()); - } else { - return first.roundFloor(second.getValue()); - } + return castDecimalV3Literal(first.roundFloor(second.getValue()), first.getValue().precision()); } } diff --git a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy index 5f728651267e82..dbfd3fad7bf913 100644 --- a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy +++ b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_numeric_arithmatic.groovy @@ -409,4 +409,29 @@ test { //Additional cases for Xor, Conv, and other mathematical functions testFoldConst("SELECT CONV(-10, 10, 2) AS conv_invalid_base") //Conv with negative input (may be undefined) + + // fix floor/ceil/round function return type with DecimalV3 input + testFoldConst("with cte as (select floor(300.343) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, 2) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, 2) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, 0) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, 0) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, 0) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, 0) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, -1) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, -1) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, -1) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, -1) order by 1 limit 1) select * from cte") + + testFoldConst("with cte as (select floor(300.343, -4) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select round(300.343, -4) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select ceil(300.343, -4) order by 1 limit 1) select * from cte") + testFoldConst("with cte as (select truncate(300.343, -4) order by 1 limit 1) select * from cte") }